• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpnp / 10566927269

26 Aug 2024 08:38PM UTC coverage: 58.919% (+0.007%) from 58.912%
10566927269

push

github

web-flow
Implemented BLAS backend for work with oneMKL Interfaces (#1981)

* Implemented BLAS backend for work with oneMKL Interfaces

* update gemm

* Update gemv

* Update gemv_impl

* Update gemm_batch

* Fix pre-commit

* updates to remove duplication

* fix two issues: 1) when order is given as "A" 2) when axes is given and column_major is called

---------

Co-authored-by: Vahid Tavanashad <vahid.tavanashad@intel.com>
Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>

3628 of 9322 branches covered (38.92%)

Branch coverage included in aggregate %.

72 of 76 new or added lines in 8 files covered. (94.74%)

2 existing lines in 2 files now uncovered.

13502 of 19752 relevant lines covered (68.36%)

17026.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.78
/dpnp/backend/extensions/blas/gemm_batch.cpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
//   this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
//   this list of conditions and the following disclaimer in the documentation
11
//   and/or other materials provided with the distribution.
12
//
13
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
//*****************************************************************************
25

26
#include <pybind11/pybind11.h>
27

28
// dpctl tensor headers
29
#include "utils/memory_overlap.hpp"
30
#include "utils/output_validation.hpp"
31
#include "utils/type_utils.hpp"
32

33
#include "gemm.hpp"
34
#include "types_matrix.hpp"
35

36
#include "dpnp_utils.hpp"
37

38
namespace dpnp::extensions::blas
39
{
40
namespace mkl_blas = oneapi::mkl::blas;
41
namespace py = pybind11;
42
namespace type_utils = dpctl::tensor::type_utils;
43

44
// Signature of a type-specialized batched GEMM implementation (instantiated
// from gemm_batch_impl<Tab, Tc> below). Entries of this type populate
// gemm_batch_dispatch_table; matrix data is passed type-erased as char
// pointers and reinterpreted by the concrete instantiation.
// Modernized from `typedef` to a `using` alias (same type, clearer syntax).
using gemm_batch_impl_fn_ptr_t = sycl::event (*)(
    sycl::queue &,          // execution queue
    const std::int64_t,     // m: rows of A and C
    const std::int64_t,     // n: columns of B and C
    const std::int64_t,     // k: columns of A / rows of B
    const std::int64_t,     // batch_size
    const std::int64_t,     // lda: leading dimension of A
    const std::int64_t,     // ldb: leading dimension of B
    const std::int64_t,     // ldc: leading dimension of C
    const std::int64_t,     // stridea: stride between A matrices
    const std::int64_t,     // strideb: stride between B matrices
    const std::int64_t,     // stridec: stride between C matrices
    oneapi::mkl::transpose, // transA
    oneapi::mkl::transpose, // transB
    const char *,           // matrixA (type-erased)
    const char *,           // matrixB (type-erased)
    char *,                 // resultC (type-erased)
    const bool,             // is_row_major layout selector
    const std::vector<sycl::event> &); // dependent events
63

64
// 2D dispatch table indexed by [input type id][result type id]; populated by
// init_gemm_batch_dispatch_table(). A nullptr entry marks an unsupported
// type combination (checked at the call site in gemm_batch()).
static gemm_batch_impl_fn_ptr_t
    gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
66

67
// Type-specialized batched GEMM: computes C = A x B for `batch_size`
// matrix pairs via oneMKL strided gemm_batch, on the given queue.
//
// Tab is the element type of A and B; Tc the element type of C. The raw
// char* arguments are reinterpreted to these types. alpha is fixed to 1
// and beta to 0 (plain product, result overwrites C — see call below).
// Returns the event of the submitted gemm_batch; oneMKL/SYCL exceptions
// are collected and re-thrown as std::runtime_error.
template <typename Tab, typename Tc>
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
                                   const std::int64_t m,
                                   const std::int64_t n,
                                   const std::int64_t k,
                                   const std::int64_t batch_size,
                                   const std::int64_t lda,
                                   const std::int64_t ldb,
                                   const std::int64_t ldc,
                                   const std::int64_t stridea,
                                   const std::int64_t strideb,
                                   const std::int64_t stridec,
                                   oneapi::mkl::transpose transA,
                                   oneapi::mkl::transpose transB,
                                   const char *matrixA,
                                   const char *matrixB,
                                   char *resultC,
                                   const bool is_row_major,
                                   const std::vector<sycl::event> &depends)
{
    // Ensure the device supports Tab/Tc (e.g. fp64 availability).
    type_utils::validate_type_for_device<Tab>(exec_q);
    type_utils::validate_type_for_device<Tc>(exec_q);

    const Tab *a = reinterpret_cast<const Tab *>(matrixA);
    const Tab *b = reinterpret_cast<const Tab *>(matrixB);
    Tc *res = reinterpret_cast<Tc *>(resultC);

    std::stringstream error_msg;
    bool is_exception_caught = false;

    sycl::event gemm_batch_event;
    try {
        // Wrapper selecting the oneMKL layout namespace; the cuBLAS backend
        // only provides column_major, otherwise pick by is_row_major.
        auto gemm_batch_func =
            [&](sycl::queue &q, oneapi::mkl::transpose transA,
                oneapi::mkl::transpose transB, const std::int64_t m,
                const std::int64_t n, const std::int64_t k, Tab alpha,
                const Tab *a, const std::int64_t lda,
                const std::int64_t stridea, const Tab *b,
                const std::int64_t ldb, const std::int64_t strideb, Tab beta,
                Tc *c, const std::int64_t ldc, const std::int64_t stridec,
                const std::int64_t batch_size,
                const std::vector<sycl::event> &deps) -> sycl::event {
#if defined(USE_ONEMKL_CUBLAS)
            return mkl_blas::column_major::gemm_batch(
                q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
                strideb, beta, c, ldc, stridec, batch_size, deps);
#else
            if (is_row_major) {
                return mkl_blas::row_major::gemm_batch(
                    q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
                    strideb, beta, c, ldc, stridec, batch_size, deps);
            }
            else {
                return mkl_blas::column_major::gemm_batch(
                    q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
                    strideb, beta, c, ldc, stridec, batch_size, deps);
            }
#endif // USE_ONEMKL_CUBLAS
        };
        gemm_batch_event = gemm_batch_func(
            exec_q,
            transA,     // Defines the transpose operation for matrix A:
                        // 'N' indicates no transpose, 'T' for transpose,
                        // or 'C' for a conjugate transpose.
            transB,     // Same as transA but for matrix B.
            m,          // Number of rows in matrices A and C.
            n,          // Number of columns in matrices B and C.
            k,          // Number of columns in matrix A and rows in matrix B.
            Tab(1),     // Scaling factor for the product of matrices A and B.
            a,          // Pointer to matrix A.
            lda,        // Leading dimension of matrix A, which is the
                        // stride between successive rows (for row major
                        // layout).
            stridea,    // Stride between different A matrices.
            b,          // Pointer to matrix B.
            ldb,        // Leading dimension of matrix B, similar to lda.
            strideb,    // Stride between different B matrices.
            Tab(0),     // Scaling factor for matrix C.
            res,        // Pointer to matrix C, where the result is stored.
            ldc,        // Leading dimension of matrix C.
            stridec,    // Stride between different C matrices.
            batch_size, // Specifies the number of matrix multiply
                        // operations to perform.
            depends);
    } catch (oneapi::mkl::exception const &e) {
        error_msg << "Unexpected MKL exception caught during gemm_batch() "
                     "call:\nreason: "
                  << e.what();
        is_exception_caught = true;
    } catch (sycl::exception const &e) {
        error_msg
            << "Unexpected SYCL exception caught during gemm_batch() call:\n"
            << e.what();
        is_exception_caught = true;
    }

    if (is_exception_caught) // an unexpected error occurs
    {
        throw std::runtime_error(error_msg.str());
    }

    return gemm_batch_event;
}
170

171
void standardize_strides_to_nonzero(std::vector<py::ssize_t> &strides,
172
                                    const py::ssize_t *shape)
173
{
12,639✔
174
    // When shape of an array along any particular dimension is 1, the stride
175
    // along that dimension is undefined. This function standardize the strides
176
    // by calculating the non-zero value of the strides.
177
    const std::size_t ndim = strides.size();
12,639✔
178
    const bool has_zero_stride =
12,639✔
179
        std::accumulate(strides.begin(), strides.end(), 1,
12,639✔
180
                        std::multiplies<py::ssize_t>{}) == 0;
12,639✔
181

182
    if (has_zero_stride) {
12,639✔
183
        for (std::size_t i = 0; i < ndim - 1; ++i) {
3,480✔
184
            strides[i] = strides[i] == 0
2,320✔
185
                             ? std::accumulate(shape + i + 1, shape + ndim, 1,
2,320✔
186
                                               std::multiplies<py::ssize_t>{})
987✔
187
                             : strides[i];
2,320✔
188
        }
2,320✔
189
        strides[ndim - 1] = strides[ndim - 1] == 0 ? 1 : strides[ndim - 1];
1,160✔
190
    }
1,160✔
191
}
12,639✔
192

193
void standardize_strides_to_zero(std::vector<py::ssize_t> &strides,
194
                                 const py::ssize_t *shape)
195
{
12,639✔
196
    // When shape of an array along any particular dimension is 1, the stride
197
    // along that dimension is undefined. This function standardize the strides
198
    // by defining such a stride as zero. This is because for these cases,
199
    // instead of copying the array into the additional dimension for batch
200
    // multiplication, we choose to use zero as the stride between different
201
    // matrices.  Therefore, the same array is used repeatedly.
202
    const std::size_t ndim = strides.size();
12,639✔
203

204
    for (std::size_t i = 0; i < ndim; ++i) {
50,556✔
205
        if (shape[i] <= 1) {
37,917✔
206
            strides[i] = 0;
1,425✔
207
        }
1,425✔
208
    }
37,917✔
209
}
12,639✔
210

211
// Python-facing entry point for batched matrix-matrix multiplication:
// resultC = matrixA @ matrixB over the leading (batch) dimension of 3-D
// usm_ndarrays. Validates dimensionality, memory overlap, queue
// compatibility, shape agreement and writability; derives strides,
// layout (row/column major), transpose flags and leading dimensions;
// then dispatches to the type-specialized implementation.
// Returns (args-keep-alive event, gemm event, is_row_major flag).
// Throws py::value_error on any validation failure.
std::tuple<sycl::event, sycl::event, bool>
    gemm_batch(sycl::queue &exec_q,
               const dpctl::tensor::usm_ndarray &matrixA,
               const dpctl::tensor::usm_ndarray &matrixB,
               const dpctl::tensor::usm_ndarray &resultC,
               const std::vector<sycl::event> &depends = {})
{
    const int matrixA_nd = matrixA.get_ndim();
    const int matrixB_nd = matrixB.get_ndim();
    const int resultC_nd = resultC.get_ndim();

    // All three arrays must have the same number of dimensions.
    if (matrixA_nd != resultC_nd || matrixB_nd != resultC_nd) {
        throw py::value_error("The given arrays have incorrect dimensions.");
    }

    // Inputs may not share memory with the output.
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(matrixA, resultC)) {
        throw py::value_error("Input array 1 and output array are overlapping "
                              "segments of memory");
    }
    if (overlap(matrixB, resultC)) {
        throw py::value_error("Input array 2 and output array are overlapping "
                              "segments of memory");
    }

    // All USM allocations must be usable from the execution queue.
    if (!dpctl::utils::queues_are_compatible(
            exec_q,
            {matrixA.get_queue(), matrixB.get_queue(), resultC.get_queue()}))
    {
        throw py::value_error(
            "USM allocations are not compatible with the execution queue.");
    }

    // Shapes are (batch, rows, cols); check GEMM shape compatibility.
    const py::ssize_t *a_shape = matrixA.get_shape_raw();
    const py::ssize_t *b_shape = matrixB.get_shape_raw();
    const py::ssize_t *c_shape = resultC.get_shape_raw();
    const std::int64_t m = a_shape[1];
    const std::int64_t n = b_shape[2];
    const std::int64_t k = a_shape[2];
    const std::int64_t batch_size = c_shape[0];
    if (a_shape[2] != b_shape[1]) {
        throw py::value_error("The number of columns in A must be equal to "
                              "the number of rows in B.");
    }
    if (a_shape[1] != c_shape[1]) {
        throw py::value_error("The number of rows in A must be equal to "
                              "the number of rows in result array.");
    }
    if (b_shape[2] != c_shape[2]) {
        throw py::value_error("The number of columns in B must be equal to "
                              "the number of columns in result array.");
    }
    const std::int64_t src_nelems = batch_size * m * n;
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC,
                                                               src_nelems);

    // Batch strides: size-1 dimensions get stride 0 so the same matrix is
    // reused across the batch (see standardize_strides_to_zero).
    std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
    std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
    std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
    standardize_strides_to_zero(a_stride, a_shape);
    standardize_strides_to_zero(b_stride, b_shape);
    standardize_strides_to_zero(c_stride, c_shape);
    const std::int64_t stridea = a_stride[0];
    const std::int64_t strideb = b_stride[0];
    const std::int64_t stridec = c_stride[0];

    // In-matrix strides: replace undefined (zero) strides with contiguous
    // defaults before classifying the 2D base layout.
    standardize_strides_to_nonzero(a_stride, a_shape);
    standardize_strides_to_nonzero(b_stride, b_shape);
    standardize_strides_to_nonzero(c_stride, c_shape);
    const bool A_base_is_f_contig =
        a_stride[1] == 1 && a_stride[2] == a_shape[1];
    const bool A_base_is_c_contig =
        a_stride[1] == a_shape[2] && a_stride[2] == 1;
    const bool B_base_is_f_contig =
        b_stride[1] == 1 && b_stride[2] == b_shape[1];
    const bool B_base_is_c_contig =
        b_stride[1] == b_shape[2] && b_stride[2] == 1;
    const bool C_base_is_f_contig =
        c_stride[1] == 1 && c_stride[2] == c_shape[1];
    const bool C_base_is_c_contig =
        c_stride[1] == c_shape[2] && c_stride[2] == 1;

    // GEMM requires each 2D base to be contiguous in one of the two orders.
    if (!A_base_is_f_contig and !A_base_is_c_contig) {
        throw py::value_error("The 2D base of the first input array is not "
                              "c-contiguous nor f-contiguous.");
    }
    if (!B_base_is_f_contig and !B_base_is_c_contig) {
        throw py::value_error("The 2D base of the second input array is not "
                              "c-contiguous nor f-contiguous.");
    }
    if (!C_base_is_f_contig and !C_base_is_c_contig) {
        throw py::value_error("The 2D base of result array is not c-contiguous "
                              "nor f-contiguous.");
    }

    oneapi::mkl::transpose transA;
    oneapi::mkl::transpose transB;
    std::int64_t lda;
    std::int64_t ldb;

#if defined(USE_ONEMKL_CUBLAS)
    // cuBLAS backend only exposes column_major; express c-contiguous inputs
    // as transposed column-major matrices instead.
    const bool is_row_major = false;

    transA = A_base_is_c_contig ? oneapi::mkl::transpose::T
                                : oneapi::mkl::transpose::N;
    transB = B_base_is_c_contig ? oneapi::mkl::transpose::T
                                : oneapi::mkl::transpose::N;

    if (transA == oneapi::mkl::transpose::N) {
        lda = m;
    }
    else {
        lda = k;
    }
    if (transB == oneapi::mkl::transpose::N) {
        ldb = k;
    }
    else {
        ldb = n;
    }
#else
    // Default to row-major; use column-major only when both inputs are
    // f-contiguous, then no transposition is needed.
    bool is_row_major = true;
    if (A_base_is_f_contig && B_base_is_f_contig) {
        is_row_major = false;
    }

    if (is_row_major) {
        transA = A_base_is_f_contig ? oneapi::mkl::transpose::T
                                    : oneapi::mkl::transpose::N;
        transB = B_base_is_f_contig ? oneapi::mkl::transpose::T
                                    : oneapi::mkl::transpose::N;

        if (transA == oneapi::mkl::transpose::N) {
            lda = k;
        }
        else {
            lda = m;
        }
        if (transB == oneapi::mkl::transpose::N) {
            ldb = n;
        }
        else {
            ldb = k;
        }
    }
    else {
        transA = oneapi::mkl::transpose::N;
        transB = oneapi::mkl::transpose::N;
        lda = m;
        ldb = k;
    }
#endif // USE_ONEMKL_CUBLAS

    const std::int64_t ldc = is_row_major ? n : m;

    const int matrixA_typenum = matrixA.get_typenum();
    const int matrixB_typenum = matrixB.get_typenum();
    const int resultC_typenum = resultC.get_typenum();

    if (matrixA_typenum != matrixB_typenum) {
        throw py::value_error("matrixA and matrixB must be of the same type.");
    }

    // Look up the type-specialized implementation in the dispatch table.
    auto array_types = dpctl_td_ns::usm_ndarray_types();
    const int matrixAB_type_id =
        array_types.typenum_to_lookup_id(matrixA_typenum);
    const int resultC_type_id =
        array_types.typenum_to_lookup_id(resultC_typenum);

    gemm_batch_impl_fn_ptr_t gemm_batch_fn =
        gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
    if (gemm_batch_fn == nullptr) {
        throw py::value_error(
            "Types of input matrices and result matrix are mismatched.");
    }

    const char *a_typeless_ptr = matrixA.get_data();
    const char *b_typeless_ptr = matrixB.get_data();
    char *r_typeless_ptr = resultC.get_data();

    sycl::event gemm_batch_ev =
        gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
                      strideb, stridec, transA, transB, a_typeless_ptr,
                      b_typeless_ptr, r_typeless_ptr, is_row_major, depends);

    // Keep the arrays alive until the GEMM event completes.
    sycl::event args_ev = dpctl::utils::keep_args_alive(
        exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

    return std::make_tuple(args_ev, gemm_batch_ev, is_row_major);
}
4,213✔
402

403
template <typename fnT, typename Tab, typename Tc>
404
struct GemmBatchContigFactory
405
{
406
    fnT get()
407
    {
196✔
408
        if constexpr (types::GemmBatchTypePairSupportFactory<Tab,
409
                                                             Tc>::is_defined) {
8✔
410
            return gemm_batch_impl<Tab, Tc>;
8✔
411
        }
412
        else {
188✔
413
            return nullptr;
188✔
414
        }
188✔
415
    }
196✔
416
};
417

418
void init_gemm_batch_dispatch_table(void)
419
{
1✔
420
    dpctl_td_ns::DispatchTableBuilder<gemm_batch_impl_fn_ptr_t,
1✔
421
                                      GemmBatchContigFactory,
1✔
422
                                      dpctl_td_ns::num_types>
1✔
423
        contig;
1✔
424
    contig.populate_dispatch_table(gemm_batch_dispatch_table);
1✔
425
}
1✔
426
} // namespace dpnp::extensions::blas
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc