• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 12467546740

23 Dec 2024 01:24PM UTC coverage: 87.488% (-0.2%) from 87.659%
12467546740

Pull #1947

github

web-flow
Merge 3e5e3032b into 678b4cfd7
Pull Request #1947: [DO NOT MERGE] Sasha triage topk test failure amd

3127 of 3658 branches covered (85.48%)

Branch coverage included in aggregate %.

34 of 62 new or added lines in 2 files covered. (54.84%)

30 existing lines in 1 file now uncovered.

11843 of 13453 relevant lines covered (88.03%)

7100.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.22
/dpctl/tensor/_sorting.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2024 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import operator
1✔
18
from typing import NamedTuple
1✔
19

20
import dpctl.tensor as dpt
1✔
21
import dpctl.tensor._tensor_impl as ti
1✔
22
import dpctl.utils as du
1✔
23

24
from ._numpy_helper import normalize_axis_index
1✔
25
from ._tensor_sorting_impl import (
1✔
26
    _argsort_ascending,
27
    _argsort_descending,
28
    _sort_ascending,
29
    _sort_descending,
30
    _topk,
31
)
32
from ._tensor_sorting_radix_impl import (
1✔
33
    _radix_argsort_ascending,
34
    _radix_argsort_descending,
35
    _radix_sort_ascending,
36
    _radix_sort_descending,
37
    _radix_sort_dtype_supported,
38
)
39

40
# Public sorting entry points defined in this module.
__all__ = ["sort", "argsort", "top_k"]
1✔
41

42

43
def _get_mergesort_impl_fn(descending):
    """Pick the merge-sort implementation function for the sort order.

    Returns `_sort_descending` when `descending` is truthy and
    `_sort_ascending` otherwise.
    """
    if descending:
        return _sort_descending
    return _sort_ascending
1✔
45

46

47
def _get_radixsort_impl_fn(descending):
    """Pick the radix-sort implementation function for the sort order.

    Returns `_radix_sort_descending` when `descending` is truthy and
    `_radix_sort_ascending` otherwise.
    """
    if descending:
        return _radix_sort_descending
    return _radix_sort_ascending
1✔
49

50

51
def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
    """sort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns a sorted copy of an input array `x`.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            must sort along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array must be sorted in descending
            order (by value). If `False`, the array must be sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain the relative order of
            `x` values which compare as equal. Default: `True`.
            Note: this function's body does not read `stable`; the
            algorithm is chosen from `kind` and the data type only.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            Sorting algorithm. The default is `"stable"`, which uses parallel
            merge-sort or parallel radix-sort algorithms depending on the
            array data type.

    Returns:
        usm_ndarray:
            a sorted array. The returned array has the same data type and
            the same shape as the input array `x`.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`.
        ValueError:
            if `kind` is not one of the supported names, or if
            `kind="radixsort"` is requested for a data type that
            `_radix_sort_dtype_supported` rejects.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    ndim = x.ndim
    if ndim == 0:
        # Nothing to reorder in a 0-d array: `axis` is still validated as
        # though the array were one-dimensional, then a copy is returned.
        normalize_axis_index(axis, ndim=1, msg_prefix="axis")
        return dpt.copy(x, order="C")
    axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")

    # The implementation is invoked with `trailing_dims_to_sort=1`, so the
    # requested axis is moved to the last position when it is not there yet.
    sort_axis_is_last = axis + 1 == ndim
    if sort_axis_is_last:
        axes_order = list(range(ndim))
        src = x
    else:
        axes_order = [d for d in range(ndim) if d != axis]
        axes_order.append(axis)
        src = dpt.permute_dims(x, axes_order)

    algo = "stable" if kind is None else kind
    supported_kinds = ("stable", "radixsort", "mergesort")
    if not (isinstance(algo, str) and algo in supported_kinds):
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{algo}'"
        )

    if algo == "mergesort":
        use_radix = False
    elif algo == "radixsort":
        if not _radix_sort_dtype_supported(x.dtype.num):
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
        use_radix = True
    else:
        # "stable": radix sort for boolean and 8/16-bit integer types,
        # merge sort for every other data type.
        radix_by_default = (
            dpt.bool,
            dpt.uint8,
            dpt.int8,
            dpt.int16,
            dpt.uint16,
        )
        use_radix = x.dtype in radix_by_default
    if use_radix:
        sort_impl = _get_radixsort_impl_fn(descending)
    else:
        sort_impl = _get_mergesort_impl_fn(descending)

    queue = x.sycl_queue
    order_manager = du.SequentialOrderManager[queue]
    pending_evs = order_manager.submitted_events
    if src.flags.c_contiguous:
        sort_input = src
        sort_deps = pending_evs
    else:
        # Non-contiguous input is first copied into a newly allocated
        # C-ordered temporary; the sort then depends on that copy.
        sort_input = dpt.empty_like(src, order="C")
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src, dst=sort_input, sycl_queue=queue, depends=pending_evs
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        sort_deps = [copy_ev]

    out = dpt.empty_like(src, order="C")
    host_ev, sort_ev = sort_impl(
        src=sort_input,
        trailing_dims_to_sort=1,
        dst=out,
        sycl_queue=queue,
        depends=sort_deps,
    )
    order_manager.add_event_pair(host_ev, sort_ev)

    if not sort_axis_is_last:
        # Undo the axis permutation so the result has the layout of `x`.
        restore_order = sorted(range(ndim), key=axes_order.__getitem__)
        out = dpt.permute_dims(out, restore_order)
    return out
1✔
155

156

157
def _get_mergeargsort_impl_fn(descending):
    """Pick the merge-sort based argsort implementation for the sort order.

    Returns `_argsort_descending` when `descending` is truthy and
    `_argsort_ascending` otherwise.
    """
    if descending:
        return _argsort_descending
    return _argsort_ascending
1✔
159

160

161
def _get_radixargsort_impl_fn(descending):
    """Pick the radix-sort based argsort implementation for the sort order.

    Returns `_radix_argsort_descending` when `descending` is truthy and
    `_radix_argsort_ascending` otherwise.
    """
    if descending:
        return _radix_argsort_descending
    return _radix_argsort_ascending
1✔
163

164

165
def argsort(x, axis=-1, descending=False, stable=True, kind=None):
    """argsort(x, axis=-1, descending=False, stable=True, kind=None)

    Returns the indices that sort an array `x` along a specified axis.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            must sort along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array must be sorted in descending
            order (by value). If `False`, the array must be sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain the relative order of
            `x` values which compare as equal. Default: `True`.
            Note: this function's body does not read `stable`; the
            algorithm is chosen from `kind` and the data type only.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            Sorting algorithm. The default is `"stable"`, which uses parallel
            merge-sort or parallel radix-sort algorithms depending on the
            array data type.

    Returns:
        usm_ndarray:
            an array of indices. The returned array has the same shape as
            the input array `x`. The return array has default array index
            data type.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`.
        ValueError:
            if `kind` is not one of the supported names, or if
            `kind="radixsort"` is requested for a data type that
            `_radix_sort_dtype_supported` rejects.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    ndim = x.ndim
    if ndim == 0:
        # A 0-d array holds a single element: `axis` is validated as though
        # the array were one-dimensional and an all-zero index array is
        # returned.
        normalize_axis_index(axis, ndim=1, msg_prefix="axis")
        return dpt.zeros_like(
            x, dtype=ti.default_device_index_type(x.sycl_queue), order="C"
        )
    axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")

    # The implementation is invoked with `trailing_dims_to_sort=1`, so the
    # requested axis is moved to the last position when it is not there yet.
    sort_axis_is_last = axis + 1 == ndim
    if sort_axis_is_last:
        axes_order = list(range(ndim))
        src = x
    else:
        axes_order = [d for d in range(ndim) if d != axis]
        axes_order.append(axis)
        src = dpt.permute_dims(x, axes_order)

    algo = "stable" if kind is None else kind
    supported_kinds = ("stable", "radixsort", "mergesort")
    if not (isinstance(algo, str) and algo in supported_kinds):
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{algo}'"
        )

    if algo == "mergesort":
        use_radix = False
    elif algo == "radixsort":
        if not _radix_sort_dtype_supported(x.dtype.num):
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
        use_radix = True
    else:
        # "stable": radix sort for boolean and 8/16-bit integer types,
        # merge sort for every other data type.
        radix_by_default = (
            dpt.bool,
            dpt.uint8,
            dpt.int8,
            dpt.int16,
            dpt.uint16,
        )
        use_radix = x.dtype in radix_by_default
    if use_radix:
        argsort_impl = _get_radixargsort_impl_fn(descending)
    else:
        argsort_impl = _get_mergeargsort_impl_fn(descending)

    queue = x.sycl_queue
    order_manager = du.SequentialOrderManager[queue]
    pending_evs = order_manager.submitted_events
    index_dt = ti.default_device_index_type(queue)
    if src.flags.c_contiguous:
        sort_input = src
        sort_deps = pending_evs
    else:
        # Non-contiguous input is first copied into a newly allocated
        # C-ordered temporary; the argsort then depends on that copy.
        sort_input = dpt.empty_like(src, order="C")
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src, dst=sort_input, sycl_queue=queue, depends=pending_evs
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        sort_deps = [copy_ev]

    out = dpt.empty_like(src, dtype=index_dt, order="C")
    host_ev, sort_ev = argsort_impl(
        src=sort_input,
        trailing_dims_to_sort=1,
        dst=out,
        sycl_queue=queue,
        depends=sort_deps,
    )
    order_manager.add_event_pair(host_ev, sort_ev)

    if not sort_axis_is_last:
        # Undo the axis permutation so the result has the layout of `x`.
        restore_order = sorted(range(ndim), key=axes_order.__getitem__)
        out = dpt.permute_dims(out, restore_order)
    return out
1✔
274

275

276
def _get_top_k_largest(mode):
1✔
277
    modes = {"largest": True, "smallest": False}
1✔
278
    try:
1✔
279
        return modes[mode]
1✔
NEW
UNCOV
280
    except KeyError:
×
NEW
UNCOV
281
        raise ValueError(
×
282
            f"`mode` must be `largest` or `smallest`. Got `{mode}`."
283
        )
284

285

286
class TopKResult(NamedTuple):
    """Pair of arrays returned by `top_k`."""

    # Elements selected by the search.
    values: dpt.usm_ndarray
    # Indices of the input array corresponding to `values`.
    indices: dpt.usm_ndarray
1✔
289

290

291
def top_k(x, k, /, *, axis=None, mode="largest"):
    """top_k(x, k, axis=None, mode="largest")

    Returns the `k` largest or smallest values and their indices in the input
    array `x` along the specified axis `axis`.

    Args:
        x (usm_ndarray):
            input array.
        k (int):
            number of elements to find. Must be a positive integer value.
            The check performed in this function rejects negative values
            and values larger than the number of elements searched.
        axis (Optional[int]):
            axis along which to search. If `None`, the search will be performed
            over the flattened array. Default: ``None``.
        mode (Literal["largest", "smallest"]):
            search mode. Must be one of the following modes:

            - `"largest"`: return the `k` largest elements.
            - `"smallest"`: return the `k` smallest elements.

            Default: `"largest"`.

    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, indices)` whose

            * first element `values` will be an array containing the `k`
              largest or smallest elements of `x`. The array has the same data
              type as `x`. If `axis` was `None`, `values` will be a
              one-dimensional array with shape `(k,)` and otherwise, `values`
              will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]`
            * second element `indices` will be an array containing indices of
              `x` that result in `values`. The array will have the same shape
              as `values` and will have the default array index data type.

    Raises:
        TypeError:
            if `x` is not a `usm_ndarray`.
        ValueError:
            if `mode` is not recognized, if `k` is negative, or if `k`
            exceeds the number of elements searched.
    """
    largest = _get_top_k_largest(mode)
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )

    k = operator.index(k)
    if k < 0:
        raise ValueError("`k` must be a positive integer value")

    ndim = x.ndim
    # Stays None unless the search axis has to be moved to the last position;
    # it then holds the permutation that must be undone on the results.
    axes_order = None
    if axis is None:
        extent = x.size
        if ndim == 0:
            # A 0-d array has exactly one element, so the result is a copy
            # of `x` paired with a zero index.
            if k > 1:
                raise ValueError(f"`k`={k} is out of bounds 1")
            return TopKResult(
                dpt.copy(x, order="C"),
                dpt.zeros_like(
                    x, dtype=ti.default_device_index_type(x.sycl_queue)
                ),
            )
        src = x
        # `None` is forwarded as `trailing_dims_to_search` for the
        # flattened search.
        search_dims = None
        out_shape = k
    else:
        axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")
        extent = x.shape[axis]
        if axis + 1 == ndim:
            src = x
        else:
            axes_order = [d for d in range(ndim) if d != axis]
            axes_order.append(axis)
            src = dpt.permute_dims(x, axes_order)
        search_dims = 1
        out_shape = src.shape[: ndim - 1] + (k,)

    if k > extent:
        raise ValueError(f"`k`={k} is out of bounds {extent}")

    queue = x.sycl_queue
    order_manager = du.SequentialOrderManager[queue]
    pending_evs = order_manager.submitted_events

    out_usm_type = src.usm_type
    if src.flags.c_contiguous:
        search_input = src
        search_deps = pending_evs
    else:
        # Non-contiguous input is first copied into a newly allocated
        # C-ordered temporary; the search then depends on that copy.
        search_input = dpt.empty_like(src, order="C")
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src, dst=search_input, sycl_queue=queue, depends=pending_evs
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        search_deps = [copy_ev]

    vals = dpt.empty(
        out_shape,
        dtype=src.dtype,
        usm_type=out_usm_type,
        order="C",
        sycl_queue=queue,
    )
    inds = dpt.empty(
        out_shape,
        dtype=ti.default_device_index_type(queue),
        usm_type=out_usm_type,
        order="C",
        sycl_queue=queue,
    )
    host_ev, topk_ev = _topk(
        src=search_input,
        trailing_dims_to_search=search_dims,
        k=k,
        largest=largest,
        vals=vals,
        inds=inds,
        sycl_queue=queue,
        depends=search_deps,
    )
    order_manager.add_event_pair(host_ev, topk_ev)

    if axes_order is not None:
        # Undo the axis permutation so the `k` axis sits where `axis` was.
        restore_order = sorted(range(ndim), key=axes_order.__getitem__)
        vals = dpt.permute_dims(vals, restore_order)
        inds = dpt.permute_dims(inds, restore_order)

    return TopKResult(vals, inds)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc