• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpnp / 16126000227

07 Jul 2025 07:24PM UTC coverage: 22.684% (-49.4%) from 72.051%
16126000227

Pull #2519

github

web-flow
Merge bd753a3a3 into 624f14f20
Pull Request #2519: tmp changes

889 of 9756 branches covered (9.11%)

Branch coverage included in aggregate %.

6317 of 22011 relevant lines covered (28.7%)

35.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

5.88
/dpnp/backend/extensions/lapack/gesv_batch.cpp
1
//*****************************************************************************
2
// Copyright (c) 2023-2025, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
//   this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
//   this list of conditions and the following disclaimer in the documentation
11
//   and/or other materials provided with the distribution.
12
//
13
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
//*****************************************************************************
25

26
#include <exception>
27
#include <stdexcept>
28

29
#include <pybind11/pybind11.h>
30

31
// dpctl tensor headers
32
#include "utils/type_utils.hpp"
33

34
#include "common_helpers.hpp"
35
#include "gesv.hpp"
36
#include "gesv_common_utils.hpp"
37
#include "types_matrix.hpp"
38

39
namespace dpnp::extensions::lapack
40
{
41
namespace mkl_lapack = oneapi::mkl::lapack;
42
namespace py = pybind11;
43
namespace type_utils = dpctl::tensor::type_utils;
44

45
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
46

47
typedef sycl::event (*gesv_batch_impl_fn_ptr_t)(
48
    sycl::queue &,
49
    const std::int64_t,
50
    const std::int64_t,
51
    const std::int64_t,
52
#if defined(USE_ONEMATH)
53
    const std::int64_t,
54
    const std::int64_t,
55
#endif // USE_ONEMATH
56
    char *,
57
    char *,
58
    const std::vector<sycl::event> &);
59

60
static gesv_batch_impl_fn_ptr_t
61
    gesv_batch_dispatch_vector[dpctl_td_ns::num_types];
62

63
template <typename T>
64
static sycl::event gesv_batch_impl(sycl::queue &exec_q,
65
                                   const std::int64_t n,
66
                                   const std::int64_t nrhs,
67
                                   const std::int64_t batch_size,
68
#if defined(USE_ONEMATH)
69
                                   const std::int64_t stride_a,
70
                                   const std::int64_t stride_b,
71
#endif // USE_ONEMATH
72
                                   char *in_a,
73
                                   char *in_b,
74
                                   const std::vector<sycl::event> &depends)
75
{
×
76
    type_utils::validate_type_for_device<T>(exec_q);
×
77

78
    T *a = reinterpret_cast<T *>(in_a);
×
79
    T *b = reinterpret_cast<T *>(in_b);
×
80

81
    const std::int64_t lda = std::max<size_t>(1UL, n);
×
82
    const std::int64_t ldb = std::max<size_t>(1UL, n);
×
83

84
    std::int64_t scratchpad_size = 0;
×
85
    sycl::event comp_event;
×
86
    std::int64_t *ipiv = nullptr;
×
87
    T *scratchpad = nullptr;
×
88

89
    std::stringstream error_msg;
×
90
    bool is_exception_caught = false;
×
91

92
#if defined(USE_ONEMATH)
93
    // Use transpose::T if the LU-factorized array is passed as C-contiguous.
94
    // For F-contiguous we use transpose::N.
95
    // Since gesv_batch takes F-contiguous as input, we use transpose::N.
96
    oneapi::mkl::transpose trans = oneapi::mkl::transpose::N;
97
    const std::int64_t stride_ipiv = n;
98

99
    scratchpad_size = std::max(
100
        mkl_lapack::getrs_batch_scratchpad_size<T>(exec_q, trans, n, nrhs, lda,
101
                                                   stride_a, stride_ipiv, ldb,
102
                                                   stride_b, batch_size),
103
        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
104
                                                   stride_ipiv, batch_size));
105

106
    scratchpad = helper::alloc_scratchpad<T>(scratchpad_size, exec_q);
107

108
    // pass batch_size * n to allocate the memory for a 2D array of pivot
109
    // indices
110
    try {
111
        ipiv = helper::alloc_ipiv(batch_size * n, exec_q);
112
    } catch (const std::exception &e) {
113
        if (scratchpad != nullptr)
114
            sycl_free_noexcept(scratchpad, exec_q);
115
        throw;
116
    }
117

118
    sycl::event getrf_batch_event;
119
    try {
120
        getrf_batch_event = mkl_lapack::getrf_batch(
121
            exec_q,
122
            n, // The order of each square matrix in the batch; (0 ≤ n).
123
               // It must be a non-negative integer.
124
            n, // The number of columns in each matrix in the batch; (0 ≤ n).
125
               // It must be a non-negative integer.
126
            a, // Pointer to the batch of square matrices, each of size (n x n).
127
            lda,      // The leading dimension of each matrix in the batch.
128
            stride_a, // Stride between consecutive matrices in the batch.
129
            ipiv, // Pointer to the array of pivot indices for each matrix in
130
                  // the batch.
131
            stride_ipiv, // Stride between pivot indices: Spacing between pivot
132
                         // arrays in 'ipiv'.
133
            batch_size,  // Stride between pivot index arrays in the batch.
134
            scratchpad,  // Pointer to scratchpad memory to be used by MKL
135
                         // routine for storing intermediate results.
136
            scratchpad_size, depends);
137

138
        comp_event = mkl_lapack::getrs_batch(
139
            exec_q,
140
            trans, // Specifies the operation: whether or not to transpose
141
                   // matrix A. Can be 'N' for no transpose, 'T' for transpose,
142
                   // and 'C' for conjugate transpose.
143
            n,     // The order of each square matrix A in the batch
144
                   // and the number of rows in each matrix B (0 ≤ n).
145
                   // It must be a non-negative integer.
146
            nrhs,  // The number of right-hand sides,
147
                   // i.e., the number of columns in each matrix B in the batch
148
                   // (0 ≤ nrhs).
149
            a,     // Pointer to the batch of square matrices A (n x n).
150
            lda,   // The leading dimension of each matrix A in the batch.
151
                   // It must be at least max(1, n).
152
            stride_a,    // Stride between individual matrices in the batch for
153
                         // matrix A.
154
            ipiv,        // Pointer to the batch of arrays of pivot indices.
155
            stride_ipiv, // Stride between pivot index arrays in the batch.
156
            b,           // Pointer to the batch of matrices B (n, nrhs).
157
            ldb,         // The leading dimension of each matrix B in the batch.
158
                         // Must be at least max(1, n).
159
            stride_b,    // Stride between individual matrices in the batch for
160
                         // matrix B.
161
            batch_size,  // The number of matrices in the batch.
162
            scratchpad,  // Pointer to scratchpad memory to be used by MKL
163
                         // routine for storing intermediate results.
164
            scratchpad_size, {getrf_batch_event});
165
    } catch (mkl_lapack::batch_error const &be) {
166
        // Get the indices of matrices within the batch that encountered an
167
        // error
168
        auto error_matrices_ids = be.ids();
169

170
        error_msg << "Singular matrix. Errors in matrices with IDs: ";
171
        for (size_t i = 0; i < error_matrices_ids.size(); ++i) {
172
            error_msg << error_matrices_ids[i];
173
            if (i < error_matrices_ids.size() - 1) {
174
                error_msg << ", ";
175
            }
176
        }
177
        error_msg << ".";
178

179
        if (scratchpad != nullptr)
180
            sycl_free_noexcept(scratchpad, exec_q);
181
        if (ipiv != nullptr)
182
            sycl_free_noexcept(ipiv, exec_q);
183

184
        throw LinAlgError(error_msg.str().c_str());
185
    } catch (mkl_lapack::exception const &e) {
186
        is_exception_caught = true;
187
        std::int64_t info = e.info();
188
        if (info < 0) {
189
            error_msg << "Parameter number " << -info
190
                      << " had an illegal value.";
191
        }
192
        else if (info == scratchpad_size && e.detail() != 0) {
193
            error_msg
194
                << "Insufficient scratchpad size. Required size is at least "
195
                << e.detail();
196
        }
197
        else {
198
            error_msg << "Unexpected MKL exception caught during getrf_batch() "
199
                         "or getrs_batch() call:\nreason: "
200
                      << e.what() << "\ninfo: " << e.info();
201
        }
202
    } catch (sycl::exception const &e) {
203
        is_exception_caught = true;
204
        error_msg << "Unexpected SYCL exception caught during getrf() or "
205
                     "getrs() call:\n"
206
                  << e.what();
207
    }
208
#else
209
    const std::int64_t a_size = n * n;
×
210
    const std::int64_t b_size = n * nrhs;
×
211

212
    // Get the number of independent linear streams
213
    const std::int64_t n_linear_streams =
×
214
        (batch_size > 16) ? 4 : ((batch_size > 4 ? 2 : 1));
×
215

216
    scratchpad_size =
×
217
        mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);
×
218

219
    scratchpad = helper::alloc_scratchpad_batch<T>(scratchpad_size,
×
220
                                                   n_linear_streams, exec_q);
×
221

222
    try {
×
223
        ipiv = helper::alloc_ipiv_batch<T>(n, n_linear_streams, exec_q);
×
224
    } catch (const std::exception &e) {
×
225
        if (scratchpad != nullptr)
×
226
            sycl_free_noexcept(scratchpad, exec_q);
×
227
        throw;
×
228
    }
×
229

230
    // Computation events to manage dependencies for each linear stream
231
    std::vector<std::vector<sycl::event>> comp_evs(n_linear_streams, depends);
×
232

233
    // Release GIL to avoid serialization of host task
234
    // submissions to the same queue in OneMKL
235
    py::gil_scoped_release release;
×
236

237
    for (std::int64_t batch_id = 0; batch_id < batch_size; ++batch_id) {
×
238
        T *a_batch = a + batch_id * a_size;
×
239
        T *b_batch = b + batch_id * b_size;
×
240

241
        std::int64_t stream_id = (batch_id % n_linear_streams);
×
242

243
        T *current_scratch_gesv = scratchpad + stream_id * scratchpad_size;
×
244
        std::int64_t *current_ipiv = ipiv + stream_id * n;
×
245

246
        // Get the event dependencies for the current stream
247
        const auto &current_dep = comp_evs[stream_id];
×
248

249
        sycl::event gesv_event;
×
250

251
        try {
×
252
            gesv_event = mkl_lapack::gesv(
×
253
                exec_q,
×
254
                n,    // The order of the square matrix A
×
255
                      // and the number of rows in matrix B (0 ≤ n).
256
                nrhs, // The number of right-hand sides,
×
257
                      // i.e., the number of columns in matrix B (0 ≤ nrhs).
258
                a_batch, // Pointer to the square coefficient matrix A (n x n).
×
259
                lda, // The leading dimension of a, must be at least max(1, n).
×
260
                current_ipiv, // The pivot indices that define the permutation
×
261
                              // matrix P; row i of the matrix was interchanged
262
                              // with row ipiv(i), must be at least max(1, n).
263
                b_batch, // Pointer to the right hand side matrix B (n x nrhs).
×
264
                ldb,     // The leading dimension of matrix B,
×
265
                         // must be at least max(1, n).
266
                current_scratch_gesv, // Pointer to scratchpad memory to be used
×
267
                                      // by MKL routine for storing intermediate
268
                                      // results.
269
                scratchpad_size, current_dep);
×
270
        } catch (mkl_lapack::exception const &e) {
×
271
            is_exception_caught = true;
×
272
            gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size,
×
273
                                          scratchpad, ipiv, e, error_msg);
×
274
        } catch (sycl::exception const &e) {
×
275
            is_exception_caught = true;
×
276
            error_msg
×
277
                << "Unexpected SYCL exception caught during gesv() call:\n"
×
278
                << e.what();
×
279
        }
×
280

281
        // Update the event dependencies for the current stream
282
        comp_evs[stream_id] = {gesv_event};
×
283
    }
×
284
#endif // USE_ONEMATH
×
285

286
    if (is_exception_caught) // an unexpected error occurs
×
287
    {
×
288
        if (scratchpad != nullptr)
×
289
            sycl_free_noexcept(scratchpad, exec_q);
×
290
        if (ipiv != nullptr)
×
291
            sycl_free_noexcept(ipiv, exec_q);
×
292
        throw std::runtime_error(error_msg.str());
×
293
    }
×
294

295
    sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
×
296
#if defined(USE_ONEMATH)
297
        cgh.depends_on(comp_event);
298
#else
299
        for (const auto &ev : comp_evs) {
×
300
            cgh.depends_on(ev);
×
301
        }
×
302
#endif // USE_ONEMATH
×
303
        auto ctx = exec_q.get_context();
×
304
        cgh.host_task([ctx, scratchpad, ipiv]() {
×
305
            sycl_free_noexcept(scratchpad, ctx);
×
306
            sycl_free_noexcept(ipiv, ctx);
×
307
        });
×
308
    });
×
309

310
    return ht_ev;
×
311
}
×
312

313
std::pair<sycl::event, sycl::event>
314
    gesv_batch(sycl::queue &exec_q,
315
               const dpctl::tensor::usm_ndarray &coeff_matrix,
316
               const dpctl::tensor::usm_ndarray &dependent_vals,
317
               const std::vector<sycl::event> &depends)
318
{
×
319
    const int coeff_matrix_nd = coeff_matrix.get_ndim();
×
320
    const int dependent_vals_nd = dependent_vals.get_ndim();
×
321

322
    const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw();
×
323
    const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw();
×
324

325
    constexpr int expected_coeff_matrix_ndim = 3;
×
326
    constexpr int min_dependent_vals_ndim = 2;
×
327
    constexpr int max_dependent_vals_ndim = 3;
×
328

329
    gesv_utils::common_gesv_checks(
×
330
        exec_q, coeff_matrix, dependent_vals, coeff_matrix_shape,
×
331
        dependent_vals_shape, expected_coeff_matrix_ndim,
×
332
        min_dependent_vals_ndim, max_dependent_vals_ndim);
×
333

334
    // Ensure `batch_size`, `n` and 'nrhs' are non-zero, otherwise return empty
335
    // events
336
    if (helper::check_zeros_shape(coeff_matrix_nd, coeff_matrix_shape) ||
×
337
        helper::check_zeros_shape(dependent_vals_nd, dependent_vals_shape))
×
338
    {
×
339
        // nothing to do
340
        return std::make_pair(sycl::event(), sycl::event());
×
341
    }
×
342

343
    if (dependent_vals_nd == 2) {
×
344
        if (coeff_matrix_shape[2] != dependent_vals_shape[1]) {
×
345
            throw py::value_error(
×
346
                "The batch_size of "
×
347
                " coeff_matrix and dependent_vals must be"
×
348
                " the same, but got " +
×
349
                std::to_string(coeff_matrix_shape[2]) + " and " +
×
350
                std::to_string(dependent_vals_shape[1]) + ".");
×
351
        }
×
352
    }
×
353
    else if (dependent_vals_nd == 3) {
×
354
        if (coeff_matrix_shape[2] != dependent_vals_shape[2]) {
×
355
            throw py::value_error(
×
356
                "The batch_size of "
×
357
                " coeff_matrix and dependent_vals must be"
×
358
                " the same, but got " +
×
359
                std::to_string(coeff_matrix_shape[2]) + " and " +
×
360
                std::to_string(dependent_vals_shape[2]) + ".");
×
361
        }
×
362
    }
×
363

364
    auto array_types = dpctl_td_ns::usm_ndarray_types();
×
365
    const int coeff_matrix_type_id =
×
366
        array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
×
367

368
    gesv_batch_impl_fn_ptr_t gesv_batch_fn =
×
369
        gesv_batch_dispatch_vector[coeff_matrix_type_id];
×
370
    if (gesv_batch_fn == nullptr) {
×
371
        throw py::value_error(
×
372
            "No gesv implementation defined for the provided type "
×
373
            "of the coefficient matrix.");
×
374
    }
×
375

376
    char *coeff_matrix_data = coeff_matrix.get_data();
×
377
    char *dependent_vals_data = dependent_vals.get_data();
×
378

379
    const std::int64_t batch_size = coeff_matrix_shape[2];
×
380
    const std::int64_t n = coeff_matrix_shape[1];
×
381
    const std::int64_t nrhs =
×
382
        (dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1;
×
383

384
    sycl::event gesv_ev;
×
385

386
#if defined(USE_ONEMATH)
387
    auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector();
388
    auto const &dependent_vals_strides = dependent_vals.get_strides_vector();
389

390
    // Get the strides for the batch matrices.
391
    // Since the matrices are stored in F-contiguous order,
392
    // the stride between batches is the last element in the strides vector.
393
    const std::int64_t coeff_matrix_batch_stride = coeff_matrix_strides.back();
394
    const std::int64_t dependent_vals_batch_stride =
395
        dependent_vals_strides.back();
396

397
    gesv_ev =
398
        gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_batch_stride,
399
                      dependent_vals_batch_stride, coeff_matrix_data,
400
                      dependent_vals_data, depends);
401
#else
402
    gesv_ev = gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data,
×
403
                            dependent_vals_data, depends);
×
404
#endif // USE_ONEMATH
×
405

406
    sycl::event ht_ev = dpctl::utils::keep_args_alive(
×
407
        exec_q, {coeff_matrix, dependent_vals}, {gesv_ev});
×
408

409
    return std::make_pair(ht_ev, gesv_ev);
×
410
}
×
411

412
template <typename fnT, typename T>
413
struct GesvBatchContigFactory
414
{
415
    fnT get()
416
    {
28✔
417
        if constexpr (types::GesvTypePairSupportFactory<T>::is_defined) {
28✔
418
            return gesv_batch_impl<T>;
8✔
419
        }
420
        else {
20✔
421
            return nullptr;
20✔
422
        }
20✔
423
    }
28✔
424
};
425

426
void init_gesv_batch_dispatch_vector(void)
427
{
2✔
428
    dpctl_td_ns::DispatchVectorBuilder<gesv_batch_impl_fn_ptr_t,
2✔
429
                                       GesvBatchContigFactory,
2✔
430
                                       dpctl_td_ns::num_types>
2✔
431
        contig;
2✔
432
    contig.populate_dispatch_vector(gesv_batch_dispatch_vector);
2✔
433
}
2✔
434
} // namespace dpnp::extensions::lapack
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc