• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 16460688165

23 Jul 2025 03:19AM UTC coverage: 85.936% (+0.05%) from 85.882%
16460688165

Pull #2098

github

web-flow
Merge 1fc9a2587 into aa05645a6
Pull Request #2098: Implement `tensor.isin`

3251 of 3906 branches covered (83.23%)

Branch coverage included in aggregate %.

116 of 120 new or added lines in 7 files covered. (96.67%)

3 existing lines in 3 files now uncovered.

12300 of 14190 relevant lines covered (86.68%)

5902.08 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.36
/dpctl/tensor/_set_functions.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2025 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
from typing import NamedTuple, Optional, Union
1✔
18

19
import dpctl.tensor as dpt
1✔
20
import dpctl.utils as du
1✔
21

22
from ._copy_utils import _empty_like_orderK
1✔
23
from ._scalar_utils import (
1✔
24
    _get_dtype,
25
    _get_queue_usm_type,
26
    _get_shape,
27
    _validate_dtype,
28
)
29
from ._tensor_elementwise_impl import _not_equal, _subtract
1✔
30
from ._tensor_impl import (
1✔
31
    _copy_usm_ndarray_into_usm_ndarray,
32
    _extract,
33
    _full_usm_ndarray,
34
    _linspace_step,
35
    _take,
36
    default_device_index_type,
37
    mask_positions,
38
)
39
from ._tensor_sorting_impl import (
1✔
40
    _argsort_ascending,
41
    _isin,
42
    _searchsorted_left,
43
    _sort_ascending,
44
)
45
from ._type_utils import (
1✔
46
    _resolve_weak_types_all_py_ints,
47
    _to_device_supported_dtype,
48
)
49

50
__all__ = [
    # public set functions
    "isin",
    "unique_values",
    "unique_counts",
    "unique_inverse",
    "unique_all",
    # namedtuple result types returned by the ``unique_*`` functions
    "UniqueAllResult",
    "UniqueCountsResult",
    "UniqueInverseResult",
]
60

61

62
class UniqueAllResult(NamedTuple):
    """Namedtuple returned by :func:`unique_all`.

    Fields are ``(values, indices, inverse_indices, counts)``.
    """

    # unique elements of the input
    values: dpt.usm_ndarray
    # indices (into the flattened input) of the first occurrence of each value
    indices: dpt.usm_ndarray
    # indices into `values` that reconstruct the input
    inverse_indices: dpt.usm_ndarray
    # number of occurrences of each unique value
    counts: dpt.usm_ndarray
67

68

69
class UniqueCountsResult(NamedTuple):
    """Namedtuple returned by :func:`unique_counts`.

    Fields are ``(values, counts)``.
    """

    # unique elements of the input
    values: dpt.usm_ndarray
    # number of occurrences of each unique value
    counts: dpt.usm_ndarray
72

73

74
class UniqueInverseResult(NamedTuple):
    """Namedtuple returned by :func:`unique_inverse`.

    Fields are ``(values, inverse_indices)``.
    """

    # unique elements of the input
    values: dpt.usm_ndarray
    # indices into `values` that reconstruct the input
    inverse_indices: dpt.usm_ndarray
77

78

79
def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
    """unique_values(x)

    Returns the unique elements of an input array `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.
    Returns:
        usm_ndarray
            an array containing the set of unique elements in `x`. The
            returned array has the same data type as `x`.
    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    exec_q = x.device.sycl_queue
    flat = x if x.ndim == 1 else dpt.reshape(x, (x.size,), order="C")
    if flat.size == 0:
        # an empty input is its own set of unique values
        return flat

    sorted_vals = dpt.empty_like(flat, order="C")
    mgr = du.SequentialOrderManager[exec_q]
    dep_evs = mgr.submitted_events

    # the sort is only fed C-contiguous data; strided inputs are staged
    # into a C-contiguous temporary first
    if flat.flags.c_contiguous:
        sort_src, sort_deps = flat, dep_evs
    else:
        sort_src = dpt.empty_like(flat, order="C")
        ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=flat, dst=sort_src, sycl_queue=exec_q, depends=dep_evs
        )
        mgr.add_event_pair(ht_ev, copy_ev)
        sort_deps = [copy_ev]
    ht_ev, sort_ev = _sort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=sorted_vals,
        sycl_queue=exec_q,
        depends=sort_deps,
    )
    mgr.add_event_pair(ht_ev, sort_ev)

    # is_first[i] is True where sorted_vals[i] differs from its predecessor,
    # i.e. where a new run of equal values begins
    is_first = dpt.empty(flat.shape, dtype="?", sycl_queue=exec_q)
    ht_ev, neq_ev = _not_equal(
        src1=sorted_vals[:-1],
        src2=sorted_vals[1:],
        dst=is_first[1:],
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    mgr.add_event_pair(ht_ev, neq_ev)
    # the first element always starts a run; is_first is a fresh
    # allocation, so there is nothing to wait for
    ht_ev, fill_ev = _full_usm_ndarray(
        fill_value=True, dst=is_first[0], sycl_queue=exec_q
    )
    mgr.add_event_pair(ht_ev, fill_ev)

    positions = dpt.empty(
        sorted_vals.shape, dtype=dpt.int64, sycl_queue=exec_q
    )
    # synchronizing call; the result is the number of unique values
    n_uniques = mask_positions(
        is_first, positions, sycl_queue=exec_q, depends=[fill_ev, neq_ev]
    )
    if n_uniques == flat.size:
        # every element is distinct: the sorted copy is already the answer
        return sorted_vals

    unique_vals = dpt.empty(
        n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
    )
    ht_ev, extract_ev = _extract(
        src=sorted_vals,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=unique_vals,
        sycl_queue=exec_q,
    )
    mgr.add_event_pair(ht_ev, extract_ev)
    return unique_vals
162

163

164
def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
    """unique_counts(x)

    Returns the unique elements of an input array `x` and the corresponding
    counts for each unique element in `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.
    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, counts)` whose

            * first element has the field name `values` and is an array
              containing the unique elements of `x`. This array has the
              same data type as `x`.
            * second element has the field name `counts` and is an array
              containing the number of times each unique element occurs in `x`.
              This array has the same shape as `values` and has the default
              array index data type.
    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    array_api_dev = x.device
    exec_q = array_api_dev.sycl_queue
    x_usm_type = x.usm_type
    if x.ndim == 1:
        fx = x
    else:
        fx = dpt.reshape(x, (x.size,), order="C")
    ind_dt = default_device_index_type(exec_q)
    if fx.size == 0:
        return UniqueCountsResult(fx, dpt.empty_like(fx, dtype=ind_dt))
    s = dpt.empty_like(fx, order="C")

    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    # the sort is only fed C-contiguous data; strided inputs are staged
    # into a C-contiguous temporary first
    if fx.flags.c_contiguous:
        sort_src, sort_deps = fx, dep_evs
    else:
        sort_src = dpt.empty_like(fx, order="C")
        ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=fx, dst=sort_src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        sort_deps = [copy_ev]
    ht_ev, sort_ev = _sort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=s,
        sycl_queue=exec_q,
        depends=sort_deps,
    )
    _manager.add_event_pair(ht_ev, sort_ev)

    # unique_mask[i] is True where s[i] differs from its predecessor, i.e.
    # where a new run of equal values begins
    unique_mask = dpt.empty(s.shape, dtype="?", sycl_queue=exec_q)
    ht_ev, uneq_ev = _not_equal(
        src1=s[:-1],
        src2=s[1:],
        dst=unique_mask[1:],
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, uneq_ev)
    # no dependency, since we write into new allocation
    ht_ev, one_ev = _full_usm_ndarray(
        fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, one_ev)
    cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
    # synchronizing call; the result is the number of unique values
    n_uniques = mask_positions(
        unique_mask, cumsum, sycl_queue=exec_q, depends=[one_ev, uneq_ev]
    )
    if n_uniques == fx.size:
        # every element is distinct, so each occurs exactly once
        return UniqueCountsResult(
            s,
            dpt.ones(
                n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
            ),
        )
    unique_vals = dpt.empty(
        n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
    )
    # populate unique values
    ht_ev, ex_e = _extract(
        src=s,
        cumsum=cumsum,
        axis_start=0,
        axis_end=1,
        dst=unique_vals,
        sycl_queue=exec_q,
    )
    _manager.add_event_pair(ht_ev, ex_e)

    # cum_unique_counts[k] is the offset in `s` where the k-th run starts;
    # the trailing entry holds the total size, so adjacent differences are
    # the run lengths, i.e. the counts
    cum_unique_counts = dpt.empty(
        n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
    )
    idx = dpt.empty(x.size, dtype=ind_dt, sycl_queue=exec_q)
    # idx = [0, 1, ..., x.size - 1]; writing into new allocation, no dependency
    ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q)
    _manager.add_event_pair(ht_ev, id_ev)
    ht_ev, extr_ev = _extract(
        src=idx,
        cumsum=cumsum,
        axis_start=0,
        axis_end=1,
        dst=cum_unique_counts[:-1],
        sycl_queue=exec_q,
        depends=[id_ev],
    )
    _manager.add_event_pair(ht_ev, extr_ev)
    # no dependency, writing into disjoint segment of new allocation
    ht_ev, set_ev = _full_usm_ndarray(
        x.size, dst=cum_unique_counts[-1], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, set_ev)
    _counts = dpt.empty_like(cum_unique_counts[1:])
    ht_ev, sub_ev = _subtract(
        src1=cum_unique_counts[1:],
        src2=cum_unique_counts[:-1],
        dst=_counts,
        sycl_queue=exec_q,
        depends=[set_ev, extr_ev],
    )
    _manager.add_event_pair(ht_ev, sub_ev)
    return UniqueCountsResult(unique_vals, _counts)
295

296

297
def unique_inverse(x: dpt.usm_ndarray) -> UniqueInverseResult:
    """unique_inverse(x)

    Returns the unique elements of an input array x and the indices from the
    set of unique elements that reconstruct `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.
    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, inverse_indices)` whose

            * first element has the field name `values` and is an array
              containing the unique elements of `x`. The array has the same
              data type as `x`.
            * second element has the field name `inverse_indices` and is an
              array containing the indices of values that reconstruct `x`.
              The array has the same shape as `x` and has the default array
              index data type.
    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    array_api_dev = x.device
    exec_q = array_api_dev.sycl_queue
    x_usm_type = x.usm_type
    ind_dt = default_device_index_type(exec_q)
    if x.ndim == 1:
        fx = x
    else:
        fx = dpt.reshape(x, (x.size,), order="C")
    sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
    if fx.size == 0:
        unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
        return UniqueInverseResult(fx, dpt.reshape(unsorting_ids, x.shape))

    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    # the argsort is only fed C-contiguous data; strided inputs are staged
    # into a C-contiguous temporary first
    if fx.flags.c_contiguous:
        sort_src, sort_deps = fx, dep_evs
    else:
        sort_src = dpt.empty_like(fx, order="C")
        ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=fx, dst=sort_src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        sort_deps = [copy_ev]
    ht_ev, sort_ev = _argsort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=sorting_ids,
        sycl_queue=exec_q,
        depends=sort_deps,
    )
    _manager.add_event_pair(ht_ev, sort_ev)

    s = dpt.empty_like(fx)
    # s = fx[sorting_ids]
    ht_ev, take_ev = _take(
        src=fx,
        ind=(sorting_ids,),
        dst=s,
        axis_start=0,
        mode=0,
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, take_ev)

    # unique_mask[i] is True where s[i] differs from its predecessor, i.e.
    # where a new run of equal values begins
    unique_mask = dpt.empty(fx.shape, dtype="?", sycl_queue=exec_q)
    ht_ev, uneq_ev = _not_equal(
        src1=s[:-1],
        src2=s[1:],
        dst=unique_mask[1:],
        sycl_queue=exec_q,
        depends=[take_ev],
    )
    _manager.add_event_pair(ht_ev, uneq_ev)
    # no dependency, writing into new allocation
    ht_ev, one_ev = _full_usm_ndarray(
        fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, one_ev)
    cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
    # synchronizing call; the result is the number of unique values
    n_uniques = mask_positions(
        unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev]
    )
    if n_uniques == fx.size:
        # every element is distinct: the inverse indices are the inverse of
        # the sorting permutation, obtained by argsorting it. Computed only
        # here because no other path consumes it.
        unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
        ht_ev, argsort_ev = _argsort_ascending(
            src=sorting_ids,
            trailing_dims_to_sort=1,
            dst=unsorting_ids,
            sycl_queue=exec_q,
            depends=[sort_ev],
        )
        _manager.add_event_pair(ht_ev, argsort_ev)
        return UniqueInverseResult(s, dpt.reshape(unsorting_ids, x.shape))

    unique_vals = dpt.empty(
        n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
    )
    ht_ev, uv_ev = _extract(
        src=s,
        cumsum=cumsum,
        axis_start=0,
        axis_end=1,
        dst=unique_vals,
        sycl_queue=exec_q,
    )
    _manager.add_event_pair(ht_ev, uv_ev)

    # locate each element of `x` among the sorted unique values
    inv = dpt.empty_like(x, dtype=ind_dt, order="C")
    ht_ev, ssl_ev = _searchsorted_left(
        hay=unique_vals,
        needles=x,
        positions=inv,
        sycl_queue=exec_q,
        depends=[uv_ev],
    )
    _manager.add_event_pair(ht_ev, ssl_ev)

    return UniqueInverseResult(unique_vals, inv)
454

455

456
def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
    """unique_all(x)

    Returns the unique elements of an input array `x`, the first occurring
    indices for each unique element in `x`, the indices from the set of unique
    elements that reconstruct `x`, and the corresponding counts for each
    unique element in `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.
    Returns:
        tuple[usm_ndarray, usm_ndarray, usm_ndarray, usm_ndarray]
            a namedtuple `(values, indices, inverse_indices, counts)` whose

            * first element has the field name `values` and is an array
              containing the unique elements of `x`. The array has the same
              data type as `x`.
            * second element has the field name `indices` and is an array
              containing the indices (of first occurrences) of `x` that
              result in `values`. The array has the same shape as `values`
              and has the default array index data type.
            * third element has the field name `inverse_indices` and is an
              array containing the indices of values that reconstruct `x`.
              The array has the same shape as `x` and has the default array
              index data type.
            * fourth element has the field name `counts` and is an array
              containing the number of times each unique element occurs in `x`.
              This array has the same shape as `values` and has the default
              array index data type.
    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    array_api_dev = x.device
    exec_q = array_api_dev.sycl_queue
    x_usm_type = x.usm_type
    ind_dt = default_device_index_type(exec_q)
    if x.ndim == 1:
        fx = x
    else:
        fx = dpt.reshape(x, (x.size,), order="C")
    sorting_ids = dpt.empty_like(fx, dtype=ind_dt, order="C")
    if fx.size == 0:
        # original array contains no data
        # so it can be safely returned as values
        unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
        return UniqueAllResult(
            fx,
            sorting_ids,
            dpt.reshape(unsorting_ids, x.shape),
            dpt.empty_like(fx, dtype=ind_dt),
        )
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    # the argsort is only fed C-contiguous data; strided inputs are staged
    # into a C-contiguous temporary first
    if fx.flags.c_contiguous:
        sort_src, sort_deps = fx, dep_evs
    else:
        sort_src = dpt.empty_like(fx, order="C")
        ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=fx, dst=sort_src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        sort_deps = [copy_ev]
    ht_ev, sort_ev = _argsort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=sorting_ids,
        sycl_queue=exec_q,
        depends=sort_deps,
    )
    _manager.add_event_pair(ht_ev, sort_ev)

    s = dpt.empty_like(fx)
    # s = fx[sorting_ids]
    ht_ev, take_ev = _take(
        src=fx,
        ind=(sorting_ids,),
        dst=s,
        axis_start=0,
        mode=0,
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, take_ev)

    # unique_mask[i] is True where s[i] differs from its predecessor, i.e.
    # where a new run of equal values begins
    unique_mask = dpt.empty(fx.shape, dtype="?", sycl_queue=exec_q)
    ht_ev, uneq_ev = _not_equal(
        src1=s[:-1],
        src2=s[1:],
        dst=unique_mask[1:],
        sycl_queue=exec_q,
        depends=[take_ev],
    )
    _manager.add_event_pair(ht_ev, uneq_ev)
    # no dependency, writing into new allocation
    ht_ev, one_ev = _full_usm_ndarray(
        fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, one_ev)
    cumsum = dpt.empty(unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q)
    # synchronizing call; the result is the number of unique values
    n_uniques = mask_positions(
        unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev]
    )
    if n_uniques == fx.size:
        # every element is distinct: the inverse indices are the inverse of
        # the sorting permutation, obtained by argsorting it. Computed only
        # here because no other path consumes it.
        unsorting_ids = dpt.empty_like(sorting_ids, dtype=ind_dt, order="C")
        ht_ev, args_ev = _argsort_ascending(
            src=sorting_ids,
            trailing_dims_to_sort=1,
            dst=unsorting_ids,
            sycl_queue=exec_q,
            depends=[sort_ev],
        )
        _manager.add_event_pair(ht_ev, args_ev)
        _counts = dpt.ones(
            n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
        )
        return UniqueAllResult(
            s,
            sorting_ids,
            dpt.reshape(unsorting_ids, x.shape),
            _counts,
        )
    unique_vals = dpt.empty(
        n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
    )
    ht_ev, uv_ev = _extract(
        src=s,
        cumsum=cumsum,
        axis_start=0,
        axis_end=1,
        dst=unique_vals,
        sycl_queue=exec_q,
    )
    _manager.add_event_pair(ht_ev, uv_ev)

    # cum_unique_counts[k] is the offset in `s` where the k-th run starts;
    # the trailing entry holds the total size, so adjacent differences are
    # the run lengths, i.e. the counts
    cum_unique_counts = dpt.empty(
        n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q
    )
    idx = dpt.empty(x.size, dtype=ind_dt, sycl_queue=exec_q)
    # idx = [0, 1, ..., x.size - 1]
    ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q)
    _manager.add_event_pair(ht_ev, id_ev)
    ht_ev, extr_ev = _extract(
        src=idx,
        cumsum=cumsum,
        axis_start=0,
        axis_end=1,
        dst=cum_unique_counts[:-1],
        sycl_queue=exec_q,
        depends=[id_ev],
    )
    _manager.add_event_pair(ht_ev, extr_ev)
    ht_ev, set_ev = _full_usm_ndarray(
        x.size, dst=cum_unique_counts[-1], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, set_ev)
    _counts = dpt.empty_like(cum_unique_counts[1:])
    ht_ev, sub_ev = _subtract(
        src1=cum_unique_counts[1:],
        src2=cum_unique_counts[:-1],
        dst=_counts,
        sycl_queue=exec_q,
        depends=[set_ev, extr_ev],
    )
    _manager.add_event_pair(ht_ev, sub_ev)

    # locate each element of `x` among the sorted unique values
    inv = dpt.empty_like(x, dtype=ind_dt, order="C")
    ht_ev, ssl_ev = _searchsorted_left(
        hay=unique_vals,
        needles=x,
        positions=inv,
        sycl_queue=exec_q,
        depends=[uv_ev],
    )
    _manager.add_event_pair(ht_ev, ssl_ev)
    # sorting_ids at each run start gives the original position of that
    # value. NOTE(review): "first occurrence" relies on _argsort_ascending
    # keeping equal keys in input order (stable) -- confirm.
    return UniqueAllResult(
        unique_vals,
        sorting_ids[cum_unique_counts[:-1]],
        inv,
        _counts,
    )
640

641

642
def isin(
    x: Union[dpt.usm_ndarray, int, float, complex, bool],
    test_elements: Union[dpt.usm_ndarray, int, float, complex, bool],
    /,
    *,
    invert: Optional[bool] = False,
) -> dpt.usm_ndarray:
    """isin(x, test_elements, /, *, invert=False)

    Tests `x in test_elements` for each element of `x`. Returns a boolean array
    with the same shape as `x` that is `True` where the element is in
    `test_elements`, `False` otherwise.

    Args:
        x (Union[usm_ndarray, bool, int, float, complex]):
            input element or elements.
        test_elements (Union[usm_ndarray, bool, int, float, complex]):
            elements against which to test each value of `x`.
        invert (Optional[bool]):
            if `True`, the output results are inverted, i.e., are equivalent to
            testing `x not in test_elements` for each element of `x`.
            Default: `False`.

    Returns:
        usm_ndarray:
            an array of the inclusion test results. The returned array has a
            boolean data type and the same shape as `x`.

    Raises:
        ExecutionPlacementError: if neither argument is a USM allocation, or
            the execution queue can not be inferred from both arguments.
        ValueError: if either operand has an unsupported data type.
    """
    q1, x_usm_type = _get_queue_usm_type(x)
    q2, test_usm_type = _get_queue_usm_type(test_elements)
    if q1 is None and q2 is None:
        raise du.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments. "
            "One of the arguments must represent USM allocation and "
            "expose `__sycl_usm_array_interface__` property"
        )
    if q1 is None:
        exec_q = q2
        res_usm_type = test_usm_type
    elif q2 is None:
        exec_q = q1
        res_usm_type = x_usm_type
    else:
        exec_q = du.get_execution_queue((q1, q2))
        if exec_q is None:
            raise du.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        res_usm_type = du.get_coerced_usm_type(
            (
                x_usm_type,
                test_usm_type,
            )
        )
    du.validate_usm_type(res_usm_type, allow_none=False)
    sycl_dev = exec_q.sycl_device

    x_dt = _get_dtype(x, sycl_dev)
    test_dt = _get_dtype(test_elements, sycl_dev)
    if not all(_validate_dtype(dt) for dt in (x_dt, test_dt)):
        raise ValueError("Operands have unsupported data types")

    # normalize once so the early-return path and the kernel see the same
    # truth value
    invert = bool(invert)

    x_sh = _get_shape(x)
    if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0:
        # no element can be a member of an empty collection
        fill = dpt.ones if invert else dpt.zeros
        return fill(
            x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q
        )

    # resolve Python-scalar (weak) types, presumably against the other
    # operand's dtype, then find a common device-supported comparison type
    dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
    dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev)

    if not isinstance(x, dpt.usm_ndarray):
        x_arr = dpt.asarray(
            x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q
        )
    else:
        x_arr = x

    if not isinstance(test_elements, dpt.usm_ndarray):
        test_arr = dpt.asarray(
            test_elements, dtype=dt2, usm_type=res_usm_type, sycl_queue=exec_q
        )
    else:
        test_arr = test_elements

    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # Compare the dtype of the materialized array rather than `x_dt`: for a
    # Python scalar `x_dt` is a weak type that never equals `dt`, which
    # would force a redundant copy of an array already created in `dt`.
    if x_arr.dtype != dt:
        x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, exec_q)
        ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
            src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, ev)
    else:
        x_buf = x_arr

    if test_arr.dtype != dt:
        # copy into C-contiguous memory, because the array will be flattened
        test_buf = dpt.empty_like(
            test_arr,
            dtype=dt,
            order="C",
            usm_type=res_usm_type,
            sycl_queue=exec_q,
        )
        ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray(
            src=test_arr, dst=test_buf, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, ev)
    else:
        test_buf = test_arr

    # the membership kernel searches a flat, sorted haystack
    test_buf = dpt.reshape(test_buf, -1)
    test_buf = dpt.sort(test_buf)

    dst = dpt.empty_like(
        x_buf,
        dtype=dpt.bool,
        usm_type=res_usm_type,
        order="C",
        sycl_queue=exec_q,
    )

    # refresh: the casts and the sort above may have submitted new work
    dep_evs = _manager.submitted_events
    ht_ev, s_ev = _isin(
        needles=x_buf,
        hay=test_buf,
        dst=dst,
        sycl_queue=exec_q,
        invert=invert,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, s_ev)
    return dst
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc