• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 15818802733

23 Jun 2025 08:09AM UTC coverage: 84.992% (+0.008%) from 84.984%
15818802733

Pull #2098

github

web-flow
Merge 97f5f3643 into 84a90dc56
Pull Request #2098: Implement `tensor.isin`

2989 of 3792 branches covered (78.82%)

Branch coverage included in aggregate %.

115 of 120 new or added lines in 7 files covered. (95.83%)

9 existing lines in 4 files now uncovered.

12290 of 14185 relevant lines covered (86.64%)

6927.16 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.46
/dpctl/tensor/_sorting.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2025 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import operator
1✔
18
from typing import NamedTuple
1✔
19

20
import dpctl.tensor as dpt
1✔
21
import dpctl.tensor._tensor_impl as ti
1✔
22
import dpctl.utils as du
1✔
23

24
from ._numpy_helper import normalize_axis_index
1✔
25
from ._tensor_sorting_impl import (
1✔
26
    _argsort_ascending,
27
    _argsort_descending,
28
    _radix_argsort_ascending,
29
    _radix_argsort_descending,
30
    _radix_sort_ascending,
31
    _radix_sort_descending,
32
    _radix_sort_dtype_supported,
33
    _sort_ascending,
34
    _sort_descending,
35
    _topk,
36
)
37

38
# Public API of this module; `top_k` is a public, documented function
# alongside `sort` and `argsort`.
__all__ = ["sort", "argsort", "top_k"]
39

40

41
def _get_mergesort_impl_fn(descending):
    """Pick the merge-sort implementation matching the sort direction.

    `descending` is tested for truthiness: a truthy value selects
    `_sort_descending`, anything else selects `_sort_ascending`.
    """
    if descending:
        return _sort_descending
    return _sort_ascending
43

44

45
def _get_radixsort_impl_fn(descending):
    """Pick the radix-sort implementation matching the sort direction.

    `descending` is tested for truthiness: a truthy value selects
    `_radix_sort_descending`, anything else selects `_radix_sort_ascending`.
    """
    if descending:
        return _radix_sort_descending
    return _radix_sort_ascending
47

48

49
def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
    """sort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns a sorted copy of an input array `x`.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            must sort along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array must be sorted in descending
            order (by value). If `False`, the array must be sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain the relative order of
            `x` values which compare as equal. Default: `True`.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            Sorting algorithm. The default is `"stable"`, which uses parallel
            merge-sort or parallel radix-sort algorithms depending on the
            array data type.

    Returns:
        usm_ndarray:
            a sorted array. The returned array has the same data type and
            the same shape as the input array `x`.

    Raises:
        TypeError:
            if `x` is not a `dpctl.tensor.usm_ndarray`.
        ValueError:
            if `kind` is not one of the supported values, or if
            `kind="radixsort"` is requested for a data type that the radix
            sort implementation does not support.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    nd = x.ndim
    # A 0-d array is validated as if it had a single axis.
    axis = normalize_axis_index(axis, ndim=max(nd, 1), msg_prefix="axis")

    # Resolve the implementation before any early return so that an invalid
    # `kind` (or an unsupported radix-sort dtype) is rejected consistently,
    # including for 0-d and single-element inputs.
    if kind is None:
        kind = "stable"
    if not isinstance(kind, str) or kind not in [
        "stable",
        "radixsort",
        "mergesort",
    ]:
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{kind}'"
        )
    if kind == "mergesort":
        impl_fn = _get_mergesort_impl_fn(descending)
    elif kind == "radixsort":
        if _radix_sort_dtype_supported(x.dtype.num):
            impl_fn = _get_radixsort_impl_fn(descending)
        else:
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
    else:
        # "stable": radix sort for these small integral/bool types,
        # merge sort for everything else.
        dt = x.dtype
        if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
            impl_fn = _get_radixsort_impl_fn(descending)
        else:
            impl_fn = _get_mergesort_impl_fn(descending)
    # `stable` is not consulted by this function.

    # Covers 0-d arrays too: they always hold exactly one element.
    if x.size == 1:
        return dpt.copy(x, order="C")

    # The implementation is asked to sort the trailing dimension
    # (`trailing_dims_to_sort=1`), so move `axis` to the last position.
    if axis == nd - 1:
        perm = None
        arr = x
    else:
        perm = [i for i in range(nd) if i != axis] + [axis]
        arr = dpt.permute_dims(x, perm)

    exec_q = x.sycl_queue
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    if arr.flags.c_contiguous:
        src, src_deps = arr, dep_evs
    else:
        # Non-C-contiguous input is first copied into a C-ordered temporary;
        # the sort then depends only on that copy.
        src = dpt.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        src_deps = [copy_ev]

    res = dpt.empty_like(arr, order="C")
    ht_ev, impl_ev = impl_fn(
        src=src,
        trailing_dims_to_sort=1,
        dst=res,
        sycl_queue=exec_q,
        depends=src_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if perm is not None:
        # Undo the axis move: inv_perm[perm[d]] == d.
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(res, inv_perm)
    return res
155

156

157
def _get_mergeargsort_impl_fn(descending):
    """Pick the merge-based argsort implementation for the sort direction.

    `descending` is tested for truthiness: a truthy value selects
    `_argsort_descending`, anything else selects `_argsort_ascending`.
    """
    if descending:
        return _argsort_descending
    return _argsort_ascending
159

160

161
def _get_radixargsort_impl_fn(descending):
    """Pick the radix-based argsort implementation for the sort direction.

    `descending` is tested for truthiness: a truthy value selects
    `_radix_argsort_descending`, anything else selects
    `_radix_argsort_ascending`.
    """
    if descending:
        return _radix_argsort_descending
    return _radix_argsort_ascending
163

164

165
def argsort(x, axis=-1, descending=False, stable=True, kind=None):
    """argsort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns the indices that sort an array `x` along a specified axis.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            must sort along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array must be sorted in descending
            order (by value). If `False`, the array must be sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain the relative order of
            `x` values which compare as equal. Default: `True`.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            Sorting algorithm. The default is `"stable"`, which uses parallel
            merge-sort or parallel radix-sort algorithms depending on the
            array data type.

    Returns:
        usm_ndarray:
            an array of indices. The returned array has the same shape as
            the input array `x`. The return array has default array index
            data type.

    Raises:
        TypeError:
            if `x` is not a `dpctl.tensor.usm_ndarray`.
        ValueError:
            if `kind` is not one of the supported values, or if
            `kind="radixsort"` is requested for a data type that the radix
            sort implementation does not support.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    nd = x.ndim
    # A 0-d array is validated as if it had a single axis.
    axis = normalize_axis_index(axis, ndim=max(nd, 1), msg_prefix="axis")

    # Resolve the implementation before any early return so that an invalid
    # `kind` (or an unsupported radix-sort dtype) is rejected consistently,
    # including for 0-d and single-element inputs.
    if kind is None:
        kind = "stable"
    if not isinstance(kind, str) or kind not in [
        "stable",
        "radixsort",
        "mergesort",
    ]:
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{kind}'"
        )
    if kind == "mergesort":
        impl_fn = _get_mergeargsort_impl_fn(descending)
    elif kind == "radixsort":
        if _radix_sort_dtype_supported(x.dtype.num):
            impl_fn = _get_radixargsort_impl_fn(descending)
        else:
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
    else:
        # "stable": radix argsort for these small integral/bool types,
        # merge argsort for everything else.
        dt = x.dtype
        if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
            impl_fn = _get_radixargsort_impl_fn(descending)
        else:
            impl_fn = _get_mergeargsort_impl_fn(descending)
    # `stable` is not consulted by this function.

    exec_q = x.sycl_queue
    index_dt = ti.default_device_index_type(exec_q)

    # Covers 0-d arrays too: the only valid index of a single element is 0.
    if x.size == 1:
        return dpt.zeros_like(x, dtype=index_dt, order="C")

    # The implementation is asked to sort the trailing dimension
    # (`trailing_dims_to_sort=1`), so move `axis` to the last position.
    if axis == nd - 1:
        perm = None
        arr = x
    else:
        perm = [i for i in range(nd) if i != axis] + [axis]
        arr = dpt.permute_dims(x, perm)

    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    if arr.flags.c_contiguous:
        src, src_deps = arr, dep_evs
    else:
        # Non-C-contiguous input is first copied into a C-ordered temporary
        # (same dtype as `x`); the argsort then depends only on that copy.
        src = dpt.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        src_deps = [copy_ev]

    res = dpt.empty_like(arr, dtype=index_dt, order="C")
    ht_ev, impl_ev = impl_fn(
        src=src,
        trailing_dims_to_sort=1,
        dst=res,
        sycl_queue=exec_q,
        depends=src_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if perm is not None:
        # Undo the axis move: inv_perm[perm[d]] == d.
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(res, inv_perm)
    return res
278

279

280
def _get_top_k_largest(mode):
    """Translate a `top_k` search mode into the boolean `largest` flag.

    Args:
        mode (str):
            ``"largest"`` or ``"smallest"``.

    Returns:
        bool:
            ``True`` for ``"largest"``, ``False`` for ``"smallest"``.

    Raises:
        ValueError:
            if `mode` is any other value, including unhashable objects
            (e.g. a list) that cannot be used as a lookup key.
    """
    modes = {"largest": True, "smallest": False}
    try:
        return modes[mode]
    except (KeyError, TypeError):
        # TypeError is raised for unhashable `mode`; report it the same way
        # as an unknown mode. `from None` drops the internal lookup error
        # from the traceback.
        raise ValueError(
            f"`mode` must be `largest` or `smallest`. Got `{mode}`."
        ) from None
288

289

290
# Named pair returned by `top_k`: `values` holds the selected elements and
# `indices` holds the indices of the input that produce them.
TopKResult = NamedTuple(
    "TopKResult",
    [
        ("values", dpt.usm_ndarray),
        ("indices", dpt.usm_ndarray),
    ],
)
293

294

295
def top_k(x, k, /, *, axis=None, mode="largest"):
    """top_k(x, k, axis=None, mode="largest")

    Returns the `k` largest or smallest values and their indices in the input
    array `x` along the specified axis `axis`.

    Args:
        x (usm_ndarray):
            input array.
        k (int):
            number of elements to find. Must be a non-negative integer value
            no larger than the number of elements being searched.
        axis (Optional[int]):
            axis along which to search. If `None`, the search will be performed
            over the flattened array. Default: ``None``.
        mode (Literal["largest", "smallest"]):
            search mode. Must be one of the following modes:

            - `"largest"`: return the `k` largest elements.
            - `"smallest"`: return the `k` smallest elements.

            Default: `"largest"`.

    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, indices)` whose

            * first element `values` will be an array containing the `k`
              largest or smallest elements of `x`. The array has the same data
              type as `x`. If `axis` was `None`, `values` will be a
              one-dimensional array with shape `(k,)` and otherwise, `values`
              will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]`
            * second element `indices` will be an array containing indices of
              `x` that result in `values`. The array will have the same shape
              as `values` and will have the default array index data type.

    Raises:
        ValueError:
            if `mode` is not ``"largest"`` or ``"smallest"``, if `k` is
            negative, or if `k` exceeds the number of elements searched.
        TypeError:
            if `x` is not a `dpctl.tensor.usm_ndarray`, or if `k` is not
            an integer (via `operator.index`).
    """
    largest = _get_top_k_largest(mode)
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )

    k = operator.index(k)
    if k < 0:
        # k == 0 is accepted, so the constraint is "non-negative".
        raise ValueError("`k` must be a non-negative integer value")

    nd = x.ndim
    # Axis permutation applied to move `axis` last; `None` when no move.
    perm = None
    if axis is None:
        sz = x.size
        if nd == 0:
            if k > 1:
                raise ValueError(f"`k`={k} is out of bounds 1")
            return TopKResult(
                dpt.copy(x, order="C"),
                dpt.zeros_like(
                    x, dtype=ti.default_device_index_type(x.sycl_queue)
                ),
            )
        arr = x
        # `None` is passed as `trailing_dims_to_search` for the flattened
        # search; the result is a 1-D array of length k.
        n_search_dims = None
        res_sh = k
    else:
        axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
        sz = x.shape[axis]
        # Search along the trailing dimension, so move `axis` last.
        if axis != nd - 1:
            perm = [i for i in range(nd) if i != axis] + [axis]
            arr = dpt.permute_dims(x, perm)
        else:
            arr = x
        n_search_dims = 1
        res_sh = arr.shape[: nd - 1] + (k,)

    if k > sz:
        raise ValueError(f"`k`={k} is out of bounds {sz}")

    exec_q = x.sycl_queue
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    if arr.flags.c_contiguous:
        src, src_deps = arr, dep_evs
    else:
        # Non-C-contiguous input is first copied into a C-ordered temporary;
        # the search then depends only on that copy.
        src = dpt.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        src_deps = [copy_ev]

    res_usm_type = arr.usm_type
    vals = dpt.empty(
        res_sh,
        dtype=arr.dtype,
        usm_type=res_usm_type,
        order="C",
        sycl_queue=exec_q,
    )
    inds = dpt.empty(
        res_sh,
        dtype=ti.default_device_index_type(exec_q),
        usm_type=res_usm_type,
        order="C",
        sycl_queue=exec_q,
    )
    ht_ev, impl_ev = _topk(
        src=src,
        trailing_dims_to_search=n_search_dims,
        k=k,
        largest=largest,
        vals=vals,
        inds=inds,
        sycl_queue=exec_q,
        depends=src_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if perm is not None:
        # Undo the axis move: inv_perm[perm[d]] == d.
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        vals = dpt.permute_dims(vals, inv_perm)
        inds = dpt.permute_dims(inds, inv_perm)

    return TopKResult(vals, inds)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc