• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpnp / 12896026246

21 Jan 2025 09:17PM UTC coverage: 71.211% (+0.4%) from 70.856%
12896026246

Pull #2201

github

web-flow
Merge 925b7d83b into 356184a29
Pull Request #2201: Implement extension for `dpnp.choose`

4568 of 9390 branches covered (48.65%)

Branch coverage included in aggregate %.

282 of 333 new or added lines in 5 files covered. (84.68%)

4 existing lines in 3 files now uncovered.

16935 of 20806 relevant lines covered (81.39%)

20542.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.45
/dpnp/backend/extensions/indexing/choose.cpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
//   this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
//   this list of conditions and the following disclaimer in the documentation
11
//   and/or other materials provided with the distribution.
12
//
13
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
//*****************************************************************************
25

26
#include <algorithm>
27
#include <cstddef>
28
#include <cstdint>
29
#include <memory>
30
#include <pybind11/pybind11.h>
31
#include <pybind11/stl.h>
32
#include <sycl/sycl.hpp>
33
#include <type_traits>
34
#include <utility>
35
#include <vector>
36

37
#include "choose_kernel.hpp"
38
#include "dpctl4pybind11.hpp"
39
#include "utils/indexing_utils.hpp"
40
#include "utils/memory_overlap.hpp"
41
#include "utils/output_validation.hpp"
42
#include "utils/sycl_alloc_utils.hpp"
43
#include "utils/type_dispatch.hpp"
44

45
namespace dpnp::extensions::indexing
46
{
47

48
namespace td_ns = dpctl::tensor::type_dispatch;

// Kernel dispatch tables indexed by [index-array type id][choice/dst type id].
// Populated once by init_choose_dispatch_tables(); entries stay nullptr for
// unsupported index types (non-integral / bool — see ChooseFactory below),
// which py_choose reports as a runtime error.
static kernels::choose_fn_ptr_t choose_clip_dispatch_table[td_ns::num_types]
                                                          [td_ns::num_types];
static kernels::choose_fn_ptr_t choose_wrap_dispatch_table[td_ns::num_types]
                                                          [td_ns::num_types];

namespace py = pybind11;
56

57
// Stages the choose-kernel parameters in USM-host buffers and asynchronously
// copies them to the caller-provided device allocations.
//
// Device-side layout written by this function:
//   device_chc_ptrs      : [chcs[0].data, ..., chcs[n_chcs-1].data]
//   device_shape_strides : [shape (shape_len),
//                           inp_strides, dst_strides,       (shape_len each)
//                           chc_strides (n_chcs * shape_len)]
//   device_chc_offsets   : [chc_offsets[0], ..., chc_offsets[n_chcs-1]]
//
// A host task keeping the staging shared_ptrs alive until all three copies
// complete is appended to host_task_events. Returns the three copy events;
// the kernel launch must depend on them.
std::vector<sycl::event>
    _populate_choose_kernel_params(sycl::queue &exec_q,
                                   std::vector<sycl::event> &host_task_events,
                                   char **device_chc_ptrs,
                                   py::ssize_t *device_shape_strides,
                                   py::ssize_t *device_chc_offsets,
                                   const py::ssize_t *shape,
                                   int shape_len,
                                   std::vector<py::ssize_t> &inp_strides,
                                   std::vector<py::ssize_t> &dst_strides,
                                   std::vector<py::ssize_t> &chc_strides,
                                   std::vector<char *> &chc_ptrs,
                                   std::vector<py::ssize_t> &chc_offsets,
                                   py::ssize_t n_chcs)
{
    // USM-host allocations allow asynchronous host->device copies below;
    // shared_ptr ownership lets a host task extend their lifetime until the
    // copies are done.
    using ptr_host_allocator_T =
        dpctl::tensor::alloc_utils::usm_host_allocator<char *>;
    using ptrT = std::vector<char *, ptr_host_allocator_T>;

    ptr_host_allocator_T ptr_allocator(exec_q);
    std::shared_ptr<ptrT> host_chc_ptrs_shp =
        std::make_shared<ptrT>(n_chcs, ptr_allocator);

    using usm_host_allocatorT =
        dpctl::tensor::alloc_utils::usm_host_allocator<py::ssize_t>;
    using shT = std::vector<py::ssize_t, usm_host_allocatorT>;

    usm_host_allocatorT sz_allocator(exec_q);
    // shape + src strides + dst strides + one stride vector per choice
    std::shared_ptr<shT> host_shape_strides_shp =
        std::make_shared<shT>(shape_len * (3 + n_chcs), sz_allocator);

    std::shared_ptr<shT> host_chc_offsets_shp =
        std::make_shared<shT>(n_chcs, sz_allocator);

    // Pack [shape | inp_strides | dst_strides | chc_strides] contiguously.
    std::copy(shape, shape + shape_len, host_shape_strides_shp->begin());
    std::copy(inp_strides.begin(), inp_strides.end(),
              host_shape_strides_shp->begin() + shape_len);
    std::copy(dst_strides.begin(), dst_strides.end(),
              host_shape_strides_shp->begin() + 2 * shape_len);
    std::copy(chc_strides.begin(), chc_strides.end(),
              host_shape_strides_shp->begin() + 3 * shape_len);

    std::copy(chc_ptrs.begin(), chc_ptrs.end(), host_chc_ptrs_shp->begin());
    std::copy(chc_offsets.begin(), chc_offsets.end(),
              host_chc_offsets_shp->begin());

    // Three asynchronous host->device copies of the packed parameters.
    const sycl::event &device_chc_ptrs_copy_ev = exec_q.copy<char *>(
        host_chc_ptrs_shp->data(), device_chc_ptrs, host_chc_ptrs_shp->size());

    const sycl::event &device_shape_strides_copy_ev = exec_q.copy<py::ssize_t>(
        host_shape_strides_shp->data(), device_shape_strides,
        host_shape_strides_shp->size());

    const sycl::event &device_chc_offsets_copy_ev = exec_q.copy<py::ssize_t>(
        host_chc_offsets_shp->data(), device_chc_offsets,
        host_chc_offsets_shp->size());

    // Empty host task that captures the shared_ptrs: keeps the USM-host
    // staging buffers alive until all three copies have completed.
    const sycl::event &shared_ptr_cleanup_ev =
        exec_q.submit([&](sycl::handler &cgh) {
            cgh.depends_on({device_chc_offsets_copy_ev,
                            device_shape_strides_copy_ev,
                            device_chc_ptrs_copy_ev});
            cgh.host_task([host_chc_offsets_shp, host_shape_strides_shp,
                           host_chc_ptrs_shp]() {});
        });
    host_task_events.push_back(shared_ptr_cleanup_ev);

    // The kernel launch must wait on these copies.
    std::vector<sycl::event> param_pack_deps{device_chc_ptrs_copy_ev,
                                             device_shape_strides_copy_ev,
                                             device_chc_offsets_copy_ev};
    return param_pack_deps;
}
129

130
// copied from dpctl, remove if a similar utility is ever exposed
131
std::vector<dpctl::tensor::usm_ndarray> parse_py_chcs(const sycl::queue &q,
132
                                                      const py::object &py_chcs)
133
{
97✔
134
    py::ssize_t chc_count = py::len(py_chcs);
97✔
135
    std::vector<dpctl::tensor::usm_ndarray> res;
97✔
136
    res.reserve(chc_count);
97✔
137

138
    for (py::ssize_t i = 0; i < chc_count; ++i) {
382✔
139
        py::object el_i = py_chcs[py::cast(i)];
285✔
140
        dpctl::tensor::usm_ndarray arr_i =
285✔
141
            py::cast<dpctl::tensor::usm_ndarray>(el_i);
285✔
142
        if (!dpctl::utils::queues_are_compatible(q, {arr_i})) {
285!
NEW
143
            throw py::value_error("Choice allocation queue is not compatible "
×
NEW
144
                                  "with execution queue");
×
NEW
145
        }
×
146
        res.push_back(arr_i);
285✔
147
    }
285✔
148

149
    return res;
97✔
150
}
97✔
151

152
// Python-facing entry point backing `dpnp.choose`: selects elements from the
// choice arrays in `py_chcs` according to the index array `src` and writes
// them to `dst`.
//
// mode selects the out-of-bounds index policy: 1 -> clip dispatch table,
// 0 -> wrap dispatch table (see the `fn` selection below).
//
// Validates dimensions, shapes, dtypes, queue compatibility and memory
// overlap before packing parameters and submitting the kernel.
//
// Returns (arg_cleanup_ev, choose_generic_ev): the first keeps the Python
// arguments alive until all submitted work finishes, the second is the
// kernel's computation event.
std::pair<sycl::event, sycl::event>
    py_choose(const dpctl::tensor::usm_ndarray &src,
              const py::object &py_chcs,
              const dpctl::tensor::usm_ndarray &dst,
              uint8_t mode,
              sycl::queue &exec_q,
              const std::vector<sycl::event> &depends)
{
    std::vector<dpctl::tensor::usm_ndarray> chcs =
        parse_py_chcs(exec_q, py_chcs);

    // Python list max size must fit into py_ssize_t
    py::ssize_t n_chcs = chcs.size();

    if (n_chcs == 0) {
        throw py::value_error("List of choices is empty.");
    }

    if (mode != 0 && mode != 1) {
        throw py::value_error("Mode must be 0 or 1.");
    }

    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

    // First choice array is the representative for ndim/shape/dtype checks.
    const dpctl::tensor::usm_ndarray chc_rep = chcs[0];

    int nd = src.get_ndim();
    int dst_nd = dst.get_ndim();
    int chc_nd = chc_rep.get_ndim();

    if (nd != dst_nd || nd != chc_nd) {
        throw py::value_error("Array shapes are not consistent");
    }

    const py::ssize_t *src_shape = src.get_shape_raw();
    const py::ssize_t *dst_shape = dst.get_shape_raw();
    const py::ssize_t *chc_shape = chc_rep.get_shape_raw();

    size_t nelems = src.get_size();
    bool shapes_equal = std::equal(src_shape, src_shape + nd, dst_shape);
    shapes_equal &= std::equal(src_shape, src_shape + nd, chc_shape);

    if (!shapes_equal) {
        throw py::value_error("Array shapes don't match.");
    }

    // Nothing to compute for empty arrays — return ready events.
    if (nelems == 0) {
        return std::make_pair(sycl::event{}, sycl::event{});
    }

    char *src_data = src.get_data();
    char *dst_data = dst.get_data();

    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(src, dst)) {
        throw py::value_error("Array memory overlap.");
    }

    // trivial offsets as choose does not apply stride
    // simplification, but may in the future
    constexpr py::ssize_t src_offset = py::ssize_t(0);
    constexpr py::ssize_t dst_offset = py::ssize_t(0);

    int src_typenum = src.get_typenum();
    int dst_typenum = dst.get_typenum();
    int chc_typenum = chc_rep.get_typenum();

    auto array_types = td_ns::usm_ndarray_types();
    int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
    int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);
    int chc_type_id = array_types.typenum_to_lookup_id(chc_typenum);

    if (chc_type_id != dst_type_id) {
        throw py::type_error("Output and choice data types are not the same.");
    }

    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems);

    std::vector<char *> chc_ptrs;
    chc_ptrs.reserve(n_chcs);

    std::vector<py::ssize_t> chc_offsets;
    chc_offsets.reserve(n_chcs);

    // 0-d arrays are treated as one element (scalar special case below),
    // hence at least one stride/shape slot per array.
    auto sh_nelems = std::max<int>(nd, 1);
    std::vector<py::ssize_t> chc_strides(n_chcs * sh_nelems, 0);

    for (auto i = 0; i < n_chcs; ++i) {
        dpctl::tensor::usm_ndarray chc_ = chcs[i];

        // ndim, type, and shape are checked against the first array
        if (i > 0) {
            if (!(chc_.get_ndim() == nd)) {
                throw py::value_error(
                    "Choice array dimensions are not the same");
            }

            if (!(chc_type_id ==
                  array_types.typenum_to_lookup_id(chc_.get_typenum()))) {
                throw py::type_error(
                    "Choice array data types are not all the same.");
            }

            const py::ssize_t *chc_shape_ = chc_.get_shape_raw();
            if (!std::equal(chc_shape_, chc_shape_ + nd, chc_shape)) {
                throw py::value_error("Choice shapes are not all equal.");
            }
        }

        // check for overlap with destination
        if (overlap(dst, chc_)) {
            throw py::value_error(
                "Arrays index overlapping segments of memory");
        }

        char *chc_data = chc_.get_data();

        if (nd > 0) {
            auto chc_strides_ = chc_.get_strides_vector();
            std::copy(chc_strides_.begin(), chc_strides_.end(),
                      chc_strides.begin() + i * nd);
        }

        chc_ptrs.push_back(chc_data);
        chc_offsets.push_back(py::ssize_t(0));
    }

    // mode == 1 -> clip, mode == 0 -> wrap; dispatch on (index, value) types.
    auto fn = mode ? choose_clip_dispatch_table[src_type_id][chc_type_id]
                   : choose_wrap_dispatch_table[src_type_id][chc_type_id];

    // Null table entries mark unsupported (non-integral) index types.
    if (fn == nullptr) {
        throw std::runtime_error("Indices must be integer type, got " +
                                 std::to_string(src_type_id));
    }

    auto packed_chc_ptrs =
        dpctl::tensor::alloc_utils::smart_malloc_device<char *>(n_chcs, exec_q);

    // packed_shapes_strides = [common shape,
    //                          src.strides,
    //                          dst.strides,
    //                          chcs[0].strides,
    //                          ...,
    //                          chcs[n_chcs].strides]
    auto packed_shapes_strides =
        dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(
            (3 + n_chcs) * sh_nelems, exec_q);

    auto packed_chc_offsets =
        dpctl::tensor::alloc_utils::smart_malloc_device<py::ssize_t>(n_chcs,
                                                                     exec_q);

    std::vector<sycl::event> host_task_events;
    host_task_events.reserve(2);

    std::vector<sycl::event> pack_deps;
    if (nd == 0) {
        // special case where all inputs are scalars
        // need to pass src, dst shape=1 and strides=0
        // chc_strides already initialized to 0 so ignore
        std::array<py::ssize_t, 1> scalar_sh{1};
        std::vector<py::ssize_t> src_strides{0};
        std::vector<py::ssize_t> dst_strides{0};

        pack_deps = _populate_choose_kernel_params(
            exec_q, host_task_events, packed_chc_ptrs.get(),
            packed_shapes_strides.get(), packed_chc_offsets.get(),
            scalar_sh.data(), sh_nelems, src_strides, dst_strides, chc_strides,
            chc_ptrs, chc_offsets, n_chcs);
    }
    else {
        auto src_strides = src.get_strides_vector();
        auto dst_strides = dst.get_strides_vector();

        pack_deps = _populate_choose_kernel_params(
            exec_q, host_task_events, packed_chc_ptrs.get(),
            packed_shapes_strides.get(), packed_chc_offsets.get(), src_shape,
            sh_nelems, src_strides, dst_strides, chc_strides, chc_ptrs,
            chc_offsets, n_chcs);
    }

    // Kernel depends on both caller-supplied events and the parameter copies.
    std::vector<sycl::event> all_deps;
    all_deps.reserve(depends.size() + pack_deps.size());
    all_deps.insert(std::end(all_deps), std::begin(pack_deps),
                    std::end(pack_deps));
    all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));

    sycl::event choose_generic_ev =
        fn(exec_q, nelems, n_chcs, sh_nelems, packed_shapes_strides.get(),
           src_data, dst_data, packed_chc_ptrs.get(), src_offset, dst_offset,
           packed_chc_offsets.get(), all_deps);

    // async_smart_free releases owners
    sycl::event temporaries_cleanup_ev =
        dpctl::tensor::alloc_utils::async_smart_free(
            exec_q, {choose_generic_ev}, packed_chc_ptrs, packed_shapes_strides,
            packed_chc_offsets);

    host_task_events.push_back(temporaries_cleanup_ev);

    using dpctl::utils::keep_args_alive;
    sycl::event arg_cleanup_ev =
        keep_args_alive(exec_q, {src, py_chcs, dst}, host_task_events);

    return std::make_pair(arg_cleanup_ev, choose_generic_ev);
}
363

364
template <typename fnT, typename IndT, typename T, typename Index>
365
struct ChooseFactory
366
{
367
    fnT get()
368
    {
784✔
369
        if constexpr (std::is_integral<IndT>::value &&
370
                      !std::is_same<IndT, bool>::value) {
448✔
371
            fnT fn = kernels::choose_impl<Index, IndT, T>;
448✔
372
            return fn;
448✔
373
        }
374
        else {
336✔
375
            fnT fn = nullptr;
336✔
376
            return fn;
336✔
377
        }
336✔
378
    }
784✔
379
};
380

381
// Out-of-bounds index policies from dpctl; py_choose dispatches mode 0 to
// the wrap table and mode 1 to the clip table.
using dpctl::tensor::indexing_utils::ClipIndex;
using dpctl::tensor::indexing_utils::WrapIndex;

// Factory alias binding the wrap policy, for DispatchTableBuilder.
template <typename fnT, typename IndT, typename T>
using ChooseWrapFactory = ChooseFactory<fnT, IndT, T, WrapIndex<IndT>>;

// Factory alias binding the clip policy, for DispatchTableBuilder.
template <typename fnT, typename IndT, typename T>
using ChooseClipFactory = ChooseFactory<fnT, IndT, T, ClipIndex<IndT>>;
389

390
void init_choose_dispatch_tables(void)
391
{
2✔
392
    using namespace td_ns;
2✔
393
    using kernels::choose_fn_ptr_t;
2✔
394

395
    DispatchTableBuilder<choose_fn_ptr_t, ChooseClipFactory, num_types>
2✔
396
        dtb_choose_clip;
2✔
397
    dtb_choose_clip.populate_dispatch_table(choose_clip_dispatch_table);
2✔
398

399
    DispatchTableBuilder<choose_fn_ptr_t, ChooseWrapFactory, num_types>
2✔
400
        dtb_choose_wrap;
2✔
401
    dtb_choose_wrap.populate_dispatch_table(choose_wrap_dispatch_table);
2✔
402

403
    return;
2✔
404
}
2✔
405

406
// Initializes the dispatch tables and registers the `_choose` binding on
// the given pybind11 module.
void init_choose(py::module_ m)
{
    init_choose_dispatch_tables();

    m.def("_choose", &py_choose, "", py::arg("src"), py::arg("chcs"),
          py::arg("dst"), py::arg("mode"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());
}
416

417
} // namespace dpnp::extensions::indexing
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc