• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpnp / 10110337260

26 Jul 2024 11:14AM UTC coverage: 54.771% (+0.2%) from 54.562%
10110337260

push

github

web-flow
Implement `syevd_batch` and `heevd_batch`  (#1936)

* Implement syevd_batch and heevd_batch

* Move include dpctl type_utils header to source files

* Add memory allocation check for scratchpad

* Add more checks for scratchpad_size

* Move includes

* Allocate memory for w with expected shape

* Applied review comments

* Add common_evd_checks to reduce duplicate code

* Remove host_task_events from syevd and heevd

* Applied review comments

* Use init_evd_dispatch_table instead of init_evd_batch_dispatch_table

* Move init_evd_dispatch_table to evd_common_utils.hpp

* Add helper function check_zeros_shape

* Implement alloc_scratchpad function to evd_batch_common.hpp

* Make round_up_mult as inline

* Add comment for check_zeros_shape

* Make alloc_scratchpad as inline

3603 of 11172 branches covered (32.25%)

Branch coverage included in aggregate %.

283 of 373 new or added lines in 10 files covered. (75.87%)

1 existing line in 1 file now uncovered.

13501 of 20056 relevant lines covered (67.32%)

16645.33 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.03
/dpnp/backend/extensions/lapack/heevd_batch.cpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
//   this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
//   this list of conditions and the following disclaimer in the documentation
11
//   and/or other materials provided with the distribution.
12
//
13
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
//*****************************************************************************
25

26
#include <pybind11/stl.h>
27

28
#include "common_helpers.hpp"
29
#include "evd_batch_common.hpp"
30
#include "heevd_batch.hpp"
31

32
// dpctl tensor headers
33
#include "utils/type_utils.hpp"
34

35
namespace dpnp::extensions::lapack
36
{
37
// Short alias for the oneMKL LAPACK namespace; `heevd`,
// `heevd_scratchpad_size` and `exception` are referenced through it below.
namespace mkl_lapack = oneapi::mkl::lapack;
38
// Alias for the dpctl tensor type utilities; `validate_type_for_device` is
// called through it in `heevd_batch_impl`.
namespace type_utils = dpctl::tensor::type_utils;
39

40
template <typename T, typename RealT>
41
static sycl::event heevd_batch_impl(sycl::queue &exec_q,
42
                                    const oneapi::mkl::job jobz,
43
                                    const oneapi::mkl::uplo upper_lower,
44
                                    const std::int64_t batch_size,
45
                                    const std::int64_t n,
46
                                    char *in_a,
47
                                    char *out_w,
48
                                    const std::vector<sycl::event> &depends)
49
{
24✔
50
    type_utils::validate_type_for_device<T>(exec_q);
24✔
51
    type_utils::validate_type_for_device<RealT>(exec_q);
24✔
52

53
    T *a = reinterpret_cast<T *>(in_a);
24✔
54
    RealT *w = reinterpret_cast<RealT *>(out_w);
24✔
55

56
    const std::int64_t a_size = n * n;
24✔
57
    const std::int64_t w_size = n;
24✔
58

59
    const std::int64_t lda = std::max<size_t>(1UL, n);
24✔
60

61
    // Get the number of independent linear streams
62
    const std::int64_t n_linear_streams =
24✔
63
        (batch_size > 16) ? 4 : ((batch_size > 4 ? 2 : 1));
24!
64

65
    const std::int64_t scratchpad_size =
24✔
66
        mkl_lapack::heevd_scratchpad_size<T>(exec_q, jobz, upper_lower, n, lda);
24✔
67

68
    T *scratchpad =
24✔
69
        evd::alloc_scratchpad<T>(scratchpad_size, n_linear_streams, exec_q);
24✔
70

71
    // Computation events to manage dependencies for each linear stream
72
    std::vector<std::vector<sycl::event>> comp_evs(n_linear_streams, depends);
24✔
73

74
    std::stringstream error_msg;
24✔
75
    std::int64_t info = 0;
24✔
76

77
    // Release GIL to avoid serialization of host task
78
    // submissions to the same queue in OneMKL
79
    py::gil_scoped_release release;
24✔
80

81
    for (std::int64_t batch_id = 0; batch_id < batch_size; ++batch_id) {
88✔
82
        T *a_batch = a + batch_id * a_size;
64✔
83
        RealT *w_batch = w + batch_id * w_size;
64✔
84

85
        std::int64_t stream_id = (batch_id % n_linear_streams);
64✔
86

87
        T *current_scratch_heevd = scratchpad + stream_id * scratchpad_size;
64✔
88

89
        // Get the event dependencies for the current stream
90
        const auto &current_dep = comp_evs[stream_id];
64✔
91

92
        sycl::event heevd_event;
64✔
93
        try {
64✔
94
            heevd_event = mkl_lapack::heevd(
64✔
95
                exec_q,
64✔
96
                jobz, // 'jobz == job::vec' means eigenvalues and eigenvectors
64✔
97
                      // are computed.
98
                upper_lower, // 'upper_lower == job::upper' means the upper
64✔
99
                             // triangular part of A, or the lower triangular
100
                             // otherwise
101
                n,           // The order of the matrix A (0 <= n)
64✔
102
                a_batch,     // Pointer to the square A (n x n)
64✔
103
                             // If 'jobz == job::vec', then on exit it will
104
                             // contain the eigenvectors of A
105
                lda, // The leading dimension of A, must be at least max(1, n)
64✔
106
                w_batch, // Pointer to array of size at least n, it will contain
64✔
107
                         // the eigenvalues of A in ascending order
108
                current_scratch_heevd, // Pointer to scratchpad memory to be
64✔
109
                                       // used by MKL routine for storing
110
                                       // intermediate results
111
                scratchpad_size, current_dep);
64✔
112
        } catch (mkl_lapack::exception const &e) {
64✔
NEW
113
            error_msg << "Unexpected MKL exception caught during heevd() "
×
NEW
114
                         "call:\nreason: "
×
NEW
115
                      << e.what() << "\ninfo: " << e.info();
×
NEW
116
            info = e.info();
×
NEW
117
        } catch (sycl::exception const &e) {
×
NEW
118
            error_msg
×
NEW
119
                << "Unexpected SYCL exception caught during heevd() call:\n"
×
NEW
120
                << e.what();
×
NEW
121
            info = -1;
×
NEW
122
        }
×
123

124
        // Update the event dependencies for the current stream
125
        comp_evs[stream_id] = {heevd_event};
64✔
126
    }
64✔
127

128
    if (info != 0) // an unexpected error occurs
24!
NEW
129
    {
×
NEW
130
        if (scratchpad != nullptr) {
×
NEW
131
            sycl::free(scratchpad, exec_q);
×
NEW
132
        }
×
NEW
133
        throw std::runtime_error(error_msg.str());
×
NEW
134
    }
×
135

136
    sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
24✔
137
        for (const auto &ev : comp_evs) {
24✔
138
            cgh.depends_on(ev);
24✔
139
        }
24✔
140
        auto ctx = exec_q.get_context();
24✔
141
        cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
24✔
142
    });
24✔
143

144
    return ht_ev;
24✔
145
}
24✔
146

147
template <typename fnT, typename T, typename RealT>
148
struct HeevdBatchContigFactory
149
{
150
    fnT get()
151
    {
196✔
152
        if constexpr (types::HeevdTypePairSupportFactory<T, RealT>::is_defined)
153
        {
2✔
154
            return heevd_batch_impl<T, RealT>;
2✔
155
        }
156
        else {
194✔
157
            return nullptr;
194✔
158
        }
194✔
159
    }
196✔
160
};
161

162
// Bring the batched EVD implementation function-pointer type into scope; it
// is the element type of the dispatch table declared in `init_heevd_batch`.
using evd::evd_batch_impl_fn_ptr_t;
163

164
// Registers the `_heevd_batch` function on the Python module `m`.
//
// A function-local static dispatch table, indexed by two type ids, is filled
// through `evd::init_evd_dispatch_table` using `HeevdBatchContigFactory`.
// The registered callable forwards its arguments together with that table to
// `evd::evd_batch_func`.
void init_heevd_batch(py::module_ m)
{
    using arrayT = dpctl::tensor::usm_ndarray;
    using event_vecT = std::vector<sycl::event>;

    static evd_batch_impl_fn_ptr_t
        heevd_batch_dispatch_table[dpctl_td_ns::num_types]
                                  [dpctl_td_ns::num_types];

    // Populate the table before the Python entry point is exposed
    evd::init_evd_dispatch_table<evd_batch_impl_fn_ptr_t,
                                 HeevdBatchContigFactory>(
        heevd_batch_dispatch_table);

    // Python-facing wrapper around the shared batched EVD driver
    auto py_heevd_batch = [&](sycl::queue &exec_q, const std::int8_t jobz,
                              const std::int8_t upper_lower, arrayT &eig_vecs,
                              arrayT &eig_vals,
                              const event_vecT &depends = {}) {
        return evd::evd_batch_func(exec_q, jobz, upper_lower, eig_vecs,
                                   eig_vals, depends,
                                   heevd_batch_dispatch_table);
    };

    m.def("_heevd_batch", py_heevd_batch,
          "Call `heevd` from OneMKL LAPACK library in a loop to return "
          "the eigenvalues and eigenvectors of a batch of complex Hermitian "
          "matrices",
          py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"),
          py::arg("eig_vecs"), py::arg("eig_vals"),
          py::arg("depends") = py::list());
}
1✔
196
} // namespace dpnp::extensions::lapack
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc