• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpnp / 13024285028

29 Jan 2025 03:51AM UTC coverage: 71.688% (+0.4%) from 71.325%
13024285028

Pull #2201

github

web-flow
Merge 6b8151655 into 0e479ccc3
Pull Request #2201: Implement extension for `dpnp.choose`

4532 of 9240 branches covered (49.05%)

Branch coverage included in aggregate %.

289 of 340 new or added lines in 5 files covered. (85.0%)

4 existing lines in 3 files now uncovered.

16981 of 20769 relevant lines covered (81.76%)

20593.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.86
/dpnp/backend/extensions/indexing/choose.cpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
//   this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
//   this list of conditions and the following disclaimer in the documentation
11
//   and/or other materials provided with the distribution.
12
//
13
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
//*****************************************************************************
25

26
#include <algorithm>
27
#include <cstddef>
28
#include <cstdint>
29
#include <memory>
30
#include <pybind11/pybind11.h>
31
#include <pybind11/stl.h>
32
#include <sycl/sycl.hpp>
33
#include <type_traits>
34
#include <utility>
35
#include <vector>
36

37
#include "choose_kernel.hpp"
38
#include "dpctl4pybind11.hpp"
39
#include "utils/indexing_utils.hpp"
40
#include "utils/memory_overlap.hpp"
41
#include "utils/output_validation.hpp"
42
#include "utils/sycl_alloc_utils.hpp"
43
#include "utils/type_dispatch.hpp"
44

45
namespace dpnp::extensions::indexing
46
{
47

48
namespace td_ns = dpctl::tensor::type_dispatch;

// Kernel dispatch tables indexed by [index-array type id][choice/output type
// id]. One table per out-of-bounds index handling mode ("clip" vs "wrap");
// both are populated once by init_choose_dispatch_tables(). Entries for
// non-integral index types are left as nullptr (see ChooseFactory below).
static kernels::choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
                                                          [td_ns::num_types];
static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
                                                          [td_ns::num_types];

namespace py = pybind11;
56

57
namespace detail
58
{
59

60
using host_ptrs_allocator_t =
61
    dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
62
using ptrs_t = std::vector<char *, host_ptrs_allocator_t>;
63
using host_ptrs_shp_t = std::shared_ptr<ptrs_t>;
64

65
host_ptrs_shp_t make_host_ptrs(sycl::queue &exec_q,
66
                               const std::vector<char *> &ptrs)
67
{
93✔
68
    host_ptrs_allocator_t ptrs_allocator(exec_q);
93✔
69
    host_ptrs_shp_t host_ptrs_shp =
93✔
70
        std::make_shared<ptrs_t>(ptrs.size(), ptrs_allocator);
93✔
71

72
    std::copy(ptrs.begin(), ptrs.end(), host_ptrs_shp->begin());
93✔
73

74
    return host_ptrs_shp;
93✔
75
}
93✔
76

77
using host_sz_allocator_t =
78
    dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t>;
79
using sz_t = std::vector<py::ssize_t, host_sz_allocator_t>;
80
using host_sz_shp_t = std::shared_ptr<sz_t>;
81

82
host_sz_shp_t make_host_offsets(sycl::queue &exec_q,
83
                                const std::vector<py::ssize_t> &offsets)
84
{
93✔
85
    host_sz_allocator_t offsets_allocator(exec_q);
93✔
86
    host_sz_shp_t host_offsets_shp =
93✔
87
        std::make_shared<sz_t>(offsets.size(), offsets_allocator);
93✔
88

89
    std::copy(offsets.begin(), offsets.end(), host_offsets_shp->begin());
93✔
90

91
    return host_offsets_shp;
93✔
92
}
93✔
93

94
/* Packs the common shape and all stride vectors into a single
   USM-host-allocated buffer laid out as
   [shape | src strides | dst strides | chc strides (n_chcs vectors)],
   owned by a shared_ptr for use with an asynchronous device copy. */
host_sz_shp_t make_host_shape_strides(sycl::queue &exec_q,
                                      py::ssize_t n_chcs,
                                      std::vector<py::ssize_t> &shape,
                                      std::vector<py::ssize_t> &inp_strides,
                                      std::vector<py::ssize_t> &dst_strides,
                                      std::vector<py::ssize_t> &chc_strides)
{
    const auto ndim = shape.size();
    host_sz_allocator_t alloc(exec_q);
    // one segment for the shape, one per strides vector (src, dst, choices)
    auto packed = std::make_shared<sz_t>(ndim * (3 + n_chcs), alloc);

    const auto base = packed->begin();
    std::copy(shape.begin(), shape.end(), base);
    std::copy(inp_strides.begin(), inp_strides.end(), base + ndim);
    std::copy(dst_strides.begin(), dst_strides.end(), base + 2 * ndim);
    std::copy(chc_strides.begin(), chc_strides.end(), base + 3 * ndim);

    return packed;
}
116

117
/* This function expects a queue and a non-trivial number of
118
   std::pairs of raw device pointers and host shared pointers
119
   (structured as <device_ptr, shared_ptr>),
120
   then enqueues a copy of the host shared pointer data into
121
   the device pointer.
122

123
   Assumes the device pointer addresses sufficient memory for
124
   the size of the host memory.
125
*/
126
template <typename... DevHostPairs>
127
std::vector<sycl::event> batched_copy(sycl::queue &exec_q,
128
                                      DevHostPairs &&...dev_host_pairs)
129
{
93✔
130
    constexpr std::size_t n = sizeof...(DevHostPairs);
93✔
131
    static_assert(n > 0, "batched_copy requires at least one argument");
93✔
132

133
    std::vector<sycl::event> copy_evs;
93✔
134
    copy_evs.reserve(n);
93✔
135
    (copy_evs.emplace_back(exec_q.copy(dev_host_pairs.second->data(),
93✔
136
                                       dev_host_pairs.first,
93✔
137
                                       dev_host_pairs.second->size())),
93✔
138
     ...);
93✔
139

140
    return copy_evs;
93✔
141
}
93✔
142

143
/* This function takes as input a queue, sycl::event dependencies,
144
   and a non-trivial number of shared_ptrs and moves them into
145
   a host_task lambda capture, ensuring their lifetime until the
146
   host_task executes.
147
*/
148
template <typename... Shps>
149
sycl::event async_shp_free(sycl::queue &exec_q,
150
                           const std::vector<sycl::event> &depends,
151
                           Shps &&...shps)
152
{
93✔
153
    constexpr std::size_t n = sizeof...(Shps);
93✔
154
    static_assert(n > 0, "async_shp_free requires at least one argument");
93✔
155

156
    const sycl::event &shared_ptr_cleanup_ev =
93✔
157
        exec_q.submit([&](sycl::handler &cgh) {
93✔
158
            cgh.depends_on(depends);
93✔
159
            cgh.host_task([capture = std::tuple(std::move(shps)...)]() {});
93✔
160
        });
93✔
161

162
    return shared_ptr_cleanup_ev;
93✔
163
}
93✔
164

165
// copied from dpctl, remove if a similar utility is ever exposed
166
std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
167
                                                      const py::object &py_chcs)
168
{
97✔
169
    py::ssize_t chc_count = py::len(py_chcs);
97✔
170
    std::vector<dpctl::tensor::usm_ndarray> res;
97✔
171
    res.reserve(chc_count);
97✔
172

173
    for (py::ssize_t i = 0; i < chc_count; ++i) {
382✔
174
        py::object el_i = py_chcs[py::cast(i)];
285✔
175
        dpctl::tensor::usm_ndarray arr_i =
285✔
176
            py::cast<dpctl::tensor::usm_ndarray>(el_i);
285✔
177
        if (!dpctl::utils::queues_are_compatible(q, {arr_i})) {
285!
NEW
178
            throw py::value_error("Choice allocation queue is not compatible "
×
NEW
179
                                  "with execution queue");
×
NEW
180
        }
×
181
        res.push_back(arr_i);
285✔
182
    }
285✔
183

184
    return res;
97✔
185
}
97✔
186

187
} // namespace detail
188

189
/* Python-exposed implementation of dpnp.choose.
 *
 * Gathers elements from the choice arrays `py_chcs` into `dst` according to
 * the index array `src`: dst[i] = chcs[src[i]][i] (element-wise over the
 * common shape).
 *
 * Parameters:
 *   src     - index array; must have an integral data type (enforced via the
 *             dispatch tables, whose non-integral entries are nullptr)
 *   py_chcs - Python sequence of choice usm_ndarrays; all must share ndim,
 *             shape, and data type
 *   dst     - output array; same shape as src, same data type as the choices
 *   mode    - out-of-bounds index handling: 1 selects "clip", any other
 *             nonzero would be rejected below, 0 selects "wrap"
 *   exec_q  - execution queue; must be compatible with all allocations
 *   depends - events the computation must wait on
 *
 * Returns a pair (args-keep-alive event, kernel event).
 * Throws py::value_error / py::type_error on validation failure and
 * std::runtime_error when no kernel exists for the index type.
 */
std::pair<sycl::event, sycl::event>
    py_choose(const dpctl::tensor::usm_ndarray &src,
              const py::object &py_chcs,
              const dpctl::tensor::usm_ndarray &dst,
              uint8_t mode,
              sycl::queue &exec_q,
              const std::vector<sycl::event> &depends)
{
    // also validates each choice's allocation queue against exec_q
    std::vector<dpctl::tensor::usm_ndarray> chcs =
        detail::parse_py_chcs(exec_q, py_chcs);

    // Python list max size must fit into py_ssize_t
    py::ssize_t n_chcs = chcs.size();

    if (n_chcs == 0) {
        throw py::value_error("List of choices is empty.");
    }

    if (mode != 0 && mode != 1) {
        throw py::value_error("Mode must be 0 or 1.");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // first choice acts as the reference for ndim/shape/type checks
    const dpctl::tensor::usm_ndarray chc_rep = chcs[0];

    int nd = src.get_ndim();
    int dst_nd = dst.get_ndim();
    int chc_nd = chc_rep.get_ndim();

    if (nd != dst_nd || nd != chc_nd) {
        throw py::value_error("Array shapes are not consistent");
    }

    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    const py::ssize_t *chc_shape = chc_rep.get_shape_raw();

    size_t nelems = src.get_size();
    bool shapes_equal = std::equal(src_shape, src_shape + nd, dst_shape);
    shapes_equal &= std::equal(src_shape, src_shape + nd, chc_shape);

    if (!shapes_equal) {
        throw py::value_error("Array shapes don't match.");
    }

    // nothing to do for empty arrays; return no-op events
    if (nelems == 0) {
        return std::make_pair(sycl::event{}, sycl::event{});
    }

    char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src, dst)) {
        throw py::value_error("Array memory overlap.");
    }

    // trivial offsets as choose does not apply stride
    // simplification, but may in the future
    constexpr py::ssize_t src_offset = py::ssize_t(0);
    constexpr py::ssize_t dst_offset = py::ssize_t(0);

    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();
    int chc_typenum = chc_rep.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
    int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);
    int chc_type_id = array_types.typenum_to_lookup_id(chc_typenum);

    if (chc_type_id != dst_type_id) {
        throw py::type_error("Output and choice data types are not the same.");
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems);

    // per-choice data pointers and offsets collected for device packing
    std::vector<char *> chc_ptrs;
    chc_ptrs.reserve(n_chcs);

    std::vector<py::ssize_t> chc_offsets;
    chc_offsets.reserve(n_chcs);

    // 0-d arrays are treated as 1 element with zero strides below
    auto sh_nelems = std::max<int>(nd, 1);
    std::vector<py::ssize_t> chc_strides(n_chcs * sh_nelems, 0);

    for (auto i = 0; i < n_chcs; ++i) {
        dpctl::tensor::usm_ndarray chc_ = chcs[i];

        // ndim, type, and shape are checked against the first array
        if (i > 0) {
            if (!(chc_.get_ndim() == nd)) {
                throw py::value_error(
                    "Choice array dimensions are not the same");
            }

            if (!(chc_type_id ==
                  array_types.typenum_to_lookup_id(chc_.get_typenum()))) {
                throw py::type_error(
                    "Choice array data types are not all the same.");
            }

            const py::ssize_t *chc_shape_ = chc_.get_shape_raw();
            if (!std::equal(chc_shape_, chc_shape_ + nd, chc_shape)) {
                throw py::value_error("Choice shapes are not all equal.");
            }
        }

        // check for overlap with destination
        if (overlap(dst, chc_)) {
            throw py::value_error(
                "Arrays index overlapping segments of memory");
        }

        char *chc_data = chc_.get_data();

        if (nd > 0) {
            auto chc_strides_ = chc_.get_strides_vector();
            std::copy(chc_strides_.begin(), chc_strides_.end(),
                      chc_strides.begin() + i * nd);
        }

        chc_ptrs.push_back(chc_data);
        chc_offsets.push_back(py::ssize_t(0));
    }

    // kernel lookup: mode==1 -> clip, mode==0 -> wrap; nullptr entries mean
    // the index type is unsupported (non-integral)
    auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
                   : choose_wrap_dispatch_table[src_type_id][chc_type_id];

    if (fn == nullptr) {
        throw std::runtime_error("Indices must be integer type, got " +
                                 std::to_string(src_type_id));
    }

    // device-side temporaries managed by smart (RAII) USM pointers
    auto packed_chc_ptrs =
        dpctl::tensor::alloc_utils::smart_malloc_device<char *>(n_chcs, exec_q);

    // packed_shapes_strides = [common shape,
    //                          src.strides,
    //                          dst.strides,
    //                          chcs[0].strides,
    //                          ...,
    //                          chcs[n_chcs].strides]
    auto packed_shapes_strides =
        dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(
            (3 + n_chcs) * sh_nelems, exec_q);

    auto packed_chc_offsets =
        dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(n_chcs,
                                                                     exec_q);

    std::vector<sycl::event> host_task_events;
    host_task_events.reserve(2);

    std::vector<sycl::event> pack_deps;
    std::vector<py::ssize_t> common_shape;
    std::vector<py::ssize_t> src_strides;
    std::vector<py::ssize_t> dst_strides;
    if (nd == 0) {
        // special case where all inputs are scalars
        // need to pass src, dst shape=1 and strides=0
        // chc_strides already initialized to 0 so ignore
        common_shape = {1};
        src_strides = {0};
        dst_strides = {0};
    }
    else {
        common_shape = src.get_shape_vector();
        src_strides = src.get_strides_vector();
        dst_strides = dst.get_strides_vector();
    }

    // stage the metadata in USM-host buffers, then copy to device
    auto host_chc_ptrs = detail::make_host_ptrs(exec_q, chc_ptrs);
    auto host_chc_offsets = detail::make_host_offsets(exec_q, chc_offsets);
    auto host_shape_strides = detail::make_host_shape_strides(
        exec_q, n_chcs, common_shape, src_strides, dst_strides, chc_strides);

    pack_deps = detail::batched_copy(
        exec_q, std::make_pair(packed_chc_ptrs.get(), host_chc_ptrs),
        std::make_pair(packed_chc_offsets.get(), host_chc_offsets),
        std::make_pair(packed_shapes_strides.get(), host_shape_strides));

    // keep the host staging buffers alive until the copies finish
    host_task_events.push_back(
        detail::async_shp_free(exec_q, pack_deps, host_chc_ptrs,
                               host_chc_offsets, host_shape_strides));

    // kernel must wait for both the metadata copies and caller dependencies
    std::vector<sycl::event> all_deps;
    all_deps.reserve(depends.size() + pack_deps.size());
    all_deps.insert(std::end(all_deps), std::begin(pack_deps),
                    std::end(pack_deps));
    all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));

    sycl::event choose_generic_ev =
        fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides.get(),
           src_data, dst_data, packed_chc_ptrs.get(), src_offset, dst_offset,
           packed_chc_offsets.get(), all_deps);

    // async_smart_free releases owners
    sycl::event temporaries_cleanup_ev =
        dpctl::tensor::alloc_utils::async_smart_free(
            exec_q, {choose_generic_ev}, packed_chc_ptrs, packed_shapes_strides,
            packed_chc_offsets);

    host_task_events.push_back(temporaries_cleanup_ev);

    using dpctl::utils::keep_args_alive;
    sycl::event arg_cleanup_ev =
        keep_args_alive(exec_q, {src, py_chcs, dst}, host_task_events);

    return std::make_pair(arg_cleanup_ev, choose_generic_ev);
}
406

407
template <typename fnT, typename IndT, typename T, typename Index>
408
struct ChooseFactory
409
{
410
    fnT get()
411
    {
784✔
412
        if constexpr (std::is_integral<IndT>::value &&
413
                      !std::is_same<IndT, bool>::value) {
448✔
414
            fnT fn = kernels::choose_impl<Index, IndT, T>;
448✔
415
            return fn;
448✔
416
        }
417
        else {
336✔
418
            fnT fn = nullptr;
336✔
419
            return fn;
336✔
420
        }
336✔
421
    }
784✔
422
};
423

424
using dpctl::tensor::indexing_utils::ClipIndex;
using dpctl::tensor::indexing_utils::WrapIndex;

// Mode-specific factories: bind ChooseFactory to the wrap (periodic) or
// clip (saturate to bounds) out-of-bounds index policy for IndT.
template <typename fnT, typename IndT, typename T>
using ChooseWrapFactory = ChooseFactory<fnT, IndT, T, WrapIndex<IndT>>;

template <typename fnT, typename IndT, typename T>
using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
432

433
void init_choose_dispatch_tables(void)
434
{
2✔
435
    using namespace td_ns;
2✔
436
    using kernels::choose_fn_ptr_t;
2✔
437

438
    DispatchTableBuilder<choose_fn_ptr_t, ChooseClipFactory, num_types>
2✔
439
        dtb_choose_clip;
2✔
440
    dtb_choose_clip.populate_dispatch_table(choose_clip_dispatch_table);
2✔
441

442
    DispatchTableBuilder<choose_fn_ptr_t, ChooseWrapFactory, num_types>
2✔
443
        dtb_choose_wrap;
2✔
444
    dtb_choose_wrap.populate_dispatch_table(choose_wrap_dispatch_table);
2✔
445

446
    return;
2✔
447
}
2✔
448

449
/* Registers the `_choose` binding on module `m` after populating the
   kernel dispatch tables. */
void init_choose(py::module_ m)
{
    dpnp::extensions::indexing::init_choose_dispatch_tables();

    // signature: _choose(src, chcs, dst, mode, sycl_queue, depends=[])
    m.def("_choose", &py_choose, "", py::arg("src"), py::arg("chcs"),
          py::arg("dst"), py::arg("mode"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());
}
459

460
} // namespace dpnp::extensions::indexing
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc