• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpnp / 16126000227

07 Jul 2025 07:24PM UTC coverage: 22.684% (-49.4%) from 72.051%
16126000227

Pull #2519

github

web-flow
Merge bd753a3a3 into 624f14f20
Pull Request #2519: tmp changes

889 of 9756 branches covered (9.11%)

Branch coverage included in aggregate %.

6317 of 22011 relevant lines covered (28.7%)

35.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

5.93
/dpnp/backend/extensions/lapack/gesvd_batch.cpp
1
//*****************************************************************************
2
// Copyright (c) 2023-2025, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
//   this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
//   this list of conditions and the following disclaimer in the documentation
11
//   and/or other materials provided with the distribution.
12
//
13
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
//*****************************************************************************
25

26
#include <stdexcept>
27

28
#include <pybind11/pybind11.h>
29

30
// dpctl tensor headers
31
#include "utils/type_utils.hpp"
32

33
#include "common_helpers.hpp"
34
#include "gesvd.hpp"
35
#include "gesvd_common_utils.hpp"
36
#include "types_matrix.hpp"
37

38
namespace dpnp::extensions::lapack
39
{
40
namespace mkl_lapack = oneapi::mkl::lapack;
41
namespace py = pybind11;
42
namespace type_utils = dpctl::tensor::type_utils;
43

44
typedef sycl::event (*gesvd_batch_impl_fn_ptr_t)(
45
    sycl::queue &,
46
    const oneapi::mkl::jobsvd,
47
    const oneapi::mkl::jobsvd,
48
    const std::int64_t,
49
    const std::int64_t,
50
    const std::int64_t,
51
    char *,
52
    const std::int64_t,
53
    char *,
54
    char *,
55
    const std::int64_t,
56
    char *,
57
    const std::int64_t,
58
    const std::vector<sycl::event> &);
59

60
static gesvd_batch_impl_fn_ptr_t
61
    gesvd_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
62

63
template <typename T, typename RealT>
64
static sycl::event gesvd_batch_impl(sycl::queue &exec_q,
65
                                    const oneapi::mkl::jobsvd jobu,
66
                                    const oneapi::mkl::jobsvd jobvt,
67
                                    const std::int64_t m,
68
                                    const std::int64_t n,
69
                                    const std::int64_t batch_size,
70
                                    char *in_a,
71
                                    const std::int64_t lda,
72
                                    char *out_s,
73
                                    char *out_u,
74
                                    const std::int64_t ldu,
75
                                    char *out_vt,
76
                                    const std::int64_t ldvt,
77
                                    const std::vector<sycl::event> &depends)
78
{
×
79
    type_utils::validate_type_for_device<T>(exec_q);
×
80
    type_utils::validate_type_for_device<RealT>(exec_q);
×
81

82
    T *a = reinterpret_cast<T *>(in_a);
×
83
    RealT *s = reinterpret_cast<RealT *>(out_s);
×
84
    T *u = reinterpret_cast<T *>(out_u);
×
85
    T *vt = reinterpret_cast<T *>(out_vt);
×
86

87
    const std::int64_t k = std::min(m, n);
×
88

89
    const std::int64_t a_size = m * n;
×
90
    const std::int64_t s_size = k;
×
91

92
    std::int64_t u_size = 0;
×
93
    std::int64_t vt_size = 0;
×
94

95
    if (jobu == oneapi::mkl::jobsvd::somevec ||
×
96
        jobu == oneapi::mkl::jobsvd::vectorsina)
×
97
    {
×
98
        u_size = m * k;
×
99
        vt_size = k * n;
×
100
    }
×
101
    else if (jobu == oneapi::mkl::jobsvd::vectors) {
×
102
        u_size = m * m;
×
103
        vt_size = n * n;
×
104
    }
×
105
    else if (jobu == oneapi::mkl::jobsvd::novec) {
×
106
        u_size = 0;
×
107
        vt_size = 0;
×
108
    }
×
109

110
    // Get the number of independent linear streams
111
    const std::int64_t n_linear_streams =
×
112
        (batch_size > 16) ? 4 : ((batch_size > 4 ? 2 : 1));
×
113

114
    const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>(
×
115
        exec_q, jobu, jobvt, m, n, lda, ldu, ldvt);
×
116

117
    T *scratchpad = helper::alloc_scratchpad_batch<T>(scratchpad_size,
×
118
                                                      n_linear_streams, exec_q);
×
119

120
    // Computation events to manage dependencies for each linear stream
121
    std::vector<std::vector<sycl::event>> comp_evs(n_linear_streams, depends);
×
122

123
    std::stringstream error_msg;
×
124
    bool is_exception_caught = false;
×
125

126
    // Release GIL to avoid serialization of host task
127
    // submissions to the same queue in OneMKL
128
    py::gil_scoped_release release;
×
129

130
    for (std::int64_t batch_id = 0; batch_id < batch_size; ++batch_id) {
×
131

132
        T *a_batch = a + batch_id * a_size;
×
133
        T *u_batch = u + batch_id * u_size;
×
134
        RealT *s_batch = s + batch_id * s_size;
×
135
        T *vt_batch = vt + batch_id * vt_size;
×
136

137
        std::int64_t stream_id = (batch_id % n_linear_streams);
×
138

139
        T *current_scratch_gesvd = scratchpad + stream_id * scratchpad_size;
×
140

141
        // Get the event dependencies for the current stream
142
        const auto &current_dep = comp_evs[stream_id];
×
143

144
        sycl::event gesvd_event;
×
145
        try {
×
146
            gesvd_event = mkl_lapack::gesvd(
×
147
                exec_q,
×
148
                jobu,  // Character specifying how to compute the matrix U:
×
149
                       // 'A' computes all columns of U,
150
                       // 'S' computes the first min(m, n) columns of U,
151
                       // 'O' overwrites A with the columns of U,
152
                       // 'N' does not compute U.
153
                jobvt, // Character specifying how to compute the matrix VT:
×
154
                       // 'A' computes all rows of VT,
155
                       // 'S' computes the first min(m, n) rows of VT,
156
                       // 'O' overwrites A with the rows of VT,
157
                       // 'N' does not compute VT.
158
                m, // The number of rows in the input batch matrix A (0 <= m).
×
159
                n, // The number of columns in the input batch matrix A (0 <=
×
160
                   // n).
161
                a_batch, // Pointer to the input batch matrix A of size (m x n)
×
162
                         // for the current batch.
163
                lda, // The leading dimension of A, must be at least max(1, m).
×
164
                s_batch, // Pointer to the array containing the singular values
×
165
                         // for the current batch.
166
                u_batch, // Pointer to the matrix U in the singular value
×
167
                         // decomposition for the current batch.
168
                ldu, // The leading dimension of U, must be at least max(1, m).
×
169
                vt_batch, // Pointer to the matrix VT in the singular value
×
170
                          // decomposition for the current batch.
171
                ldvt, // The leading dimension of VT, must be at least max(1,
×
172
                      // n).
173
                current_scratch_gesvd, // Pointer to scratchpad memory to be
×
174
                                       // used by MKL routine for storing
175
                                       // intermediate results.
176
                scratchpad_size, current_dep);
×
177
        } catch (mkl_lapack::exception const &e) {
×
178
            is_exception_caught = true;
×
179
            gesvd_utils::handle_lapack_exc(scratchpad_size, e, error_msg);
×
180
        } catch (sycl::exception const &e) {
×
181
            is_exception_caught = true;
×
182
            error_msg
×
183
                << "Unexpected SYCL exception caught during gesvd() call:\n"
×
184
                << e.what();
×
185
        }
×
186

187
        // Update the event dependencies for the current stream
188
        comp_evs[stream_id] = {gesvd_event};
×
189
    }
×
190

191
    if (is_exception_caught) // an unexpected error occurs
×
192
    {
×
193
        if (scratchpad != nullptr) {
×
194
            dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, exec_q);
×
195
        }
×
196
        throw std::runtime_error(error_msg.str());
×
197
    }
×
198

199
    sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
×
200
        for (const auto &ev : comp_evs) {
×
201
            cgh.depends_on(ev);
×
202
        }
×
203
        auto ctx = exec_q.get_context();
×
204
        cgh.host_task([ctx, scratchpad]() {
×
205
            dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, ctx);
×
206
        });
×
207
    });
×
208

209
    return ht_ev;
×
210
}
×
211

212
std::pair<sycl::event, sycl::event>
213
    gesvd_batch(sycl::queue &exec_q,
214
                const std::int8_t jobu_val,
215
                const std::int8_t jobvt_val,
216
                const dpctl::tensor::usm_ndarray &a_array,
217
                const dpctl::tensor::usm_ndarray &out_s,
218
                const dpctl::tensor::usm_ndarray &out_u,
219
                const dpctl::tensor::usm_ndarray &out_vt,
220
                const std::vector<sycl::event> &depends)
221
{
×
222
    constexpr int expected_a_u_vt_ndim = 3;
×
223
    constexpr int expected_s_ndim = 2;
×
224

225
    gesvd_utils::common_gesvd_checks(exec_q, a_array, out_s, out_u, out_vt,
×
226
                                     jobu_val, jobvt_val, expected_a_u_vt_ndim,
×
227
                                     expected_s_ndim);
×
228

229
    // Ensure `batch_size`, `m` and 'n' are non-zero, otherwise return empty
230
    // events
231
    if (gesvd_utils::check_zeros_shape_gesvd(a_array, out_s, out_u, out_vt,
×
232
                                             jobu_val, jobvt_val))
×
233
    {
×
234
        // nothing to do
235
        return std::make_pair(sycl::event(), sycl::event());
×
236
    }
×
237

238
    auto array_types = dpctl_td_ns::usm_ndarray_types();
×
239
    const int a_array_type_id =
×
240
        array_types.typenum_to_lookup_id(a_array.get_typenum());
×
241
    const int out_s_type_id =
×
242
        array_types.typenum_to_lookup_id(out_s.get_typenum());
×
243

244
    gesvd_batch_impl_fn_ptr_t gesvd_batch_fn =
×
245
        gesvd_batch_dispatch_table[a_array_type_id][out_s_type_id];
×
246
    if (gesvd_batch_fn == nullptr) {
×
247
        throw py::value_error(
×
248
            "No gesvd implementation is defined for the given pair "
×
249
            "of array type and output singular values type.");
×
250
    }
×
251

252
    char *a_array_data = a_array.get_data();
×
253
    char *out_s_data = out_s.get_data();
×
254
    char *out_u_data = out_u.get_data();
×
255
    char *out_vt_data = out_vt.get_data();
×
256

257
    const py::ssize_t *a_array_shape = a_array.get_shape_raw();
×
258

259
    // Input array have (m, n, batch_size) shape
260
    const std::int64_t batch_size = a_array_shape[2];
×
261
    const std::int64_t m = a_array_shape[0];
×
262
    const std::int64_t n = a_array_shape[1];
×
263

264
    const std::int64_t lda = std::max<size_t>(1UL, m);
×
265
    const std::int64_t ldu = std::max<size_t>(1UL, m);
×
266
    const std::int64_t ldvt =
×
267
        std::max<std::size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n);
×
268

269
    const oneapi::mkl::jobsvd jobu = gesvd_utils::process_job(jobu_val);
×
270
    const oneapi::mkl::jobsvd jobvt = gesvd_utils::process_job(jobvt_val);
×
271

272
    sycl::event gesvd_ev =
×
273
        gesvd_batch_fn(exec_q, jobu, jobvt, m, n, batch_size, a_array_data, lda,
×
274
                       out_s_data, out_u_data, ldu, out_vt_data, ldvt, depends);
×
275

276
    sycl::event ht_ev = dpctl::utils::keep_args_alive(
×
277
        exec_q, {a_array, out_s, out_u, out_vt}, {gesvd_ev});
×
278

279
    return std::make_pair(ht_ev, gesvd_ev);
×
280
}
×
281

282
template <typename fnT, typename T, typename RealT>
283
struct GesvdBatchContigFactory
284
{
285
    fnT get()
286
    {
392✔
287
        if constexpr (types::GesvdTypePairSupportFactory<T, RealT>::is_defined)
288
        {
8✔
289
            return gesvd_batch_impl<T, RealT>;
8✔
290
        }
291
        else {
384✔
292
            return nullptr;
384✔
293
        }
384✔
294
    }
392✔
295
};
296

297
void init_gesvd_batch_dispatch_table(void)
298
{
2✔
299
    dpctl_td_ns::DispatchTableBuilder<gesvd_batch_impl_fn_ptr_t,
2✔
300
                                      GesvdBatchContigFactory,
2✔
301
                                      dpctl_td_ns::num_types>
2✔
302
        contig;
2✔
303
    contig.populate_dispatch_table(gesvd_batch_dispatch_table);
2✔
304
}
2✔
305
} // namespace dpnp::extensions::lapack
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc