
IntelPython / dpnp / build 16238163645

12 Jul 2025 12:46PM UTC coverage: 72.044% (-0.08%) from 72.128%

Pull Request #2509 (github / web-flow): using `syrk` for performing special cases of matrix multiplication
Merge a45a74736 into 1a7ce2207

4887 of 9880 branches covered (49.46%); branch coverage is included in the aggregate %.

172 of 257 new or added lines in 9 files covered (66.93%)

7 existing lines in 6 files are now uncovered.

18293 of 22295 relevant lines covered (82.05%)

19226.62 hits per line

Source file: /dpnp/backend/extensions/blas/gemm_batch.cpp (78.14% covered)

//*****************************************************************************
// Copyright (c) 2024-2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <functional>
#include <numeric>
#include <sstream>
#include <stdexcept>

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/output_validation.hpp"
#include "utils/type_utils.hpp"

#include "gemm.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp::extensions::blas
{
namespace mkl_blas = oneapi::mkl::blas;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
    sycl::queue &,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    const std::int64_t,
    oneapi::mkl::transpose,
    oneapi::mkl::transpose,
    const char *,
    const char *,
    char *,
    const bool,
    const std::vector<sycl::event> &);

static gemm_batch_impl_fn_ptr_t
    gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
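
// The table is indexed as [input (A/B) type id][output (C) type id]; pairs
// without a supported implementation stay nullptr and are reported as a
// type error in gemm_batch() below.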

template <typename Tab, typename Tc>
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
                                   const std::int64_t m,
                                   const std::int64_t n,
                                   const std::int64_t k,
                                   const std::int64_t batch_size,
                                   const std::int64_t lda,
                                   const std::int64_t ldb,
                                   const std::int64_t ldc,
                                   const std::int64_t stridea,
                                   const std::int64_t strideb,
                                   const std::int64_t stridec,
                                   oneapi::mkl::transpose transA,
                                   oneapi::mkl::transpose transB,
                                   const char *matrixA,
                                   const char *matrixB,
                                   char *resultC,
                                   const bool is_row_major,
                                   const std::vector<sycl::event> &depends)
{
    type_utils::validate_type_for_device<Tab>(exec_q);
    type_utils::validate_type_for_device<Tc>(exec_q);

    const Tab *a = reinterpret_cast<const Tab *>(matrixA);
    const Tab *b = reinterpret_cast<const Tab *>(matrixB);
    Tc *res = reinterpret_cast<Tc *>(resultC);

    std::stringstream error_msg;
    bool is_exception_caught = false;

    sycl::event gemm_batch_event;
    try {
        auto gemm_batch_func =
            [&](sycl::queue &q, oneapi::mkl::transpose transA,
                oneapi::mkl::transpose transB, const std::int64_t m,
                const std::int64_t n, const std::int64_t k, Tab alpha,
                const Tab *a, const std::int64_t lda,
                const std::int64_t stridea, const Tab *b,
                const std::int64_t ldb, const std::int64_t strideb, Tab beta,
                Tc *c, const std::int64_t ldc, const std::int64_t stridec,
                const std::int64_t batch_size,
                const std::vector<sycl::event> &deps) -> sycl::event {
            if (is_row_major) {
                return mkl_blas::row_major::gemm_batch(
                    q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
                    strideb, beta, c, ldc, stridec, batch_size, deps);
            }
            else {
                return mkl_blas::column_major::gemm_batch(
                    q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
                    strideb, beta, c, ldc, stridec, batch_size, deps);
            }
        };
        gemm_batch_event = gemm_batch_func(
            exec_q,
            transA,     // Defines the transpose operation for matrix A:
                        // 'N' indicates no transpose, 'T' for transpose,
                        // or 'C' for a conjugate transpose.
            transB,     // Same as transA but for matrix B.
            m,          // Number of rows in matrices A and C.
            n,          // Number of columns in matrices B and C.
            k,          // Number of columns in matrix A and rows in matrix B.
            Tab(1),     // Scaling factor for the product of matrices A and B.
            a,          // Pointer to matrix A.
            lda,        // Leading dimension of matrix A, which is the
                        // stride between successive rows (for row major
                        // layout).
            stridea,    // Stride between different A matrices.
            b,          // Pointer to matrix B.
            ldb,        // Leading dimension of matrix B, similar to lda.
            strideb,    // Stride between different B matrices.
            Tab(0),     // Scaling factor for matrix C.
            res,        // Pointer to matrix C, where the result is stored.
            ldc,        // Leading dimension of matrix C.
            stridec,    // Stride between different C matrices.
            batch_size, // Specifies the number of matrix multiply
                        // operations to perform.
            depends);
    } catch (oneapi::mkl::exception const &e) {
        error_msg << "Unexpected MKL exception caught during gemm_batch() "
                     "call:\nreason: "
                  << e.what();
        is_exception_caught = true;
    } catch (sycl::exception const &e) {
        error_msg
            << "Unexpected SYCL exception caught during gemm_batch() call:\n"
            << e.what();
        is_exception_caught = true;
    }

    if (is_exception_caught) // an unexpected error occurred
    {
        throw std::runtime_error(error_msg.str());
    }

    return gemm_batch_event;
}
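
// Worked example (illustrative values): for a batch of two C-contiguous
// row-major products of shape (3 x 4) @ (4 x 5) = (3 x 5), gemm_batch()
// below derives m = 3, n = 5, k = 4, transA = transB = N, lda = k = 4,
// ldb = n = 5, ldc = n = 5, stridea = 3 * 4 = 12, strideb = 4 * 5 = 20
// and stridec = 3 * 5 = 15 before dispatching to this function.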

void standardize_strides_to_nonzero(std::vector<py::ssize_t> &strides,
                                    const py::ssize_t *shape)
{
    // When the shape of an array along any particular dimension is 1, the
    // stride along that dimension is undefined. This function standardizes
    // the strides by computing a non-zero value for each such stride.
    const std::size_t ndim = strides.size();
    const bool has_zero_stride =
        std::accumulate(strides.begin(), strides.end(), 1,
                        std::multiplies<py::ssize_t>{}) == 0;

    if (has_zero_stride) {
        for (std::size_t i = 0; i < ndim - 1; ++i) {
            strides[i] = strides[i] == 0
                             ? std::accumulate(shape + i + 1, shape + ndim, 1,
                                               std::multiplies<py::ssize_t>{})
                             : strides[i];
        }
        strides[ndim - 1] = strides[ndim - 1] == 0 ? 1 : strides[ndim - 1];
    }
}
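
// Example: given shape (5, 1, 3) and strides {3, 0, 1} (the form produced
// by standardize_strides_to_zero() below for a unit dimension), the product
// of the strides is 0, so the zero stride is recomputed from the trailing
// extents: strides[1] = shape[2] = 3, yielding {3, 3, 1}.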

void standardize_strides_to_zero(std::vector<py::ssize_t> &strides,
                                 const py::ssize_t *shape)
{
    // When the shape of an array along any particular dimension is 1, the
    // stride along that dimension is undefined. This function standardizes
    // the strides by defining such a stride as zero, because in these cases,
    // instead of copying the array into the additional dimension for batch
    // multiplication, we choose to use zero as the stride between different
    // matrices. Therefore, the same array is used repeatedly.
    const std::size_t ndim = strides.size();

    for (std::size_t i = 0; i < ndim; ++i) {
        if (shape[i] <= 1) {
            strides[i] = 0;
        }
    }
}
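
// Example: for shape (1, 4, 5) with strides {20, 5, 1}, dimension 0 has
// extent 1, so strides[0] becomes 0 and the strides turn into {0, 5, 1};
// with a zero batch stride, the same 4 x 5 matrix is reused for every
// multiplication in the batch.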

std::tuple<sycl::event, sycl::event, bool>
    gemm_batch(sycl::queue &exec_q,
               const dpctl::tensor::usm_ndarray &matrixA,
               const dpctl::tensor::usm_ndarray &matrixB,
               const dpctl::tensor::usm_ndarray &resultC,
               const std::vector<sycl::event> &depends = {})
{
    const int matrixA_nd = matrixA.get_ndim();
    const int matrixB_nd = matrixB.get_ndim();
    const int resultC_nd = resultC.get_ndim();

    if (matrixA_nd != resultC_nd || matrixB_nd != resultC_nd) {
        throw py::value_error("The given arrays have incorrect dimensions.");
    }

    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    if (overlap(matrixA, resultC)) {
        throw py::value_error("Input array 1 and output array are overlapping "
                              "segments of memory");
    }
    if (overlap(matrixB, resultC)) {
        throw py::value_error("Input array 2 and output array are overlapping "
                              "segments of memory");
    }

    if (!dpctl::utils::queues_are_compatible(
            exec_q,
            {matrixA.get_queue(), matrixB.get_queue(), resultC.get_queue()}))
    {
        throw py::value_error(
            "USM allocations are not compatible with the execution queue.");
    }

    const py::ssize_t *a_shape = matrixA.get_shape_raw();
    const py::ssize_t *b_shape = matrixB.get_shape_raw();
    const py::ssize_t *c_shape = resultC.get_shape_raw();
    const std::int64_t m = a_shape[1];
    const std::int64_t n = b_shape[2];
    const std::int64_t k = a_shape[2];
    const std::int64_t batch_size = c_shape[0];
    if (a_shape[2] != b_shape[1]) {
        throw py::value_error("The number of columns in A must be equal to "
                              "the number of rows in B.");
    }
    if (a_shape[1] != c_shape[1]) {
        throw py::value_error("The number of rows in A must be equal to "
                              "the number of rows in result array.");
    }
    if (b_shape[2] != c_shape[2]) {
        throw py::value_error("The number of columns in B must be equal to "
                              "the number of columns in result array.");
    }
    const std::int64_t src_nelems = batch_size * m * n;
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC);
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC,
                                                               src_nelems);

    std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
    std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
    std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
    standardize_strides_to_zero(a_stride, a_shape);
    standardize_strides_to_zero(b_stride, b_shape);
    standardize_strides_to_zero(c_stride, c_shape);
    const std::int64_t stridea = a_stride[0];
    const std::int64_t strideb = b_stride[0];
    const std::int64_t stridec = c_stride[0];

    standardize_strides_to_nonzero(a_stride, a_shape);
    standardize_strides_to_nonzero(b_stride, b_shape);
    standardize_strides_to_nonzero(c_stride, c_shape);
    const bool A_base_is_f_contig =
        a_stride[1] == 1 && a_stride[2] == a_shape[1];
    const bool A_base_is_c_contig =
        a_stride[1] == a_shape[2] && a_stride[2] == 1;
    const bool B_base_is_f_contig =
        b_stride[1] == 1 && b_stride[2] == b_shape[1];
    const bool B_base_is_c_contig =
        b_stride[1] == b_shape[2] && b_stride[2] == 1;
    const bool C_base_is_f_contig =
        c_stride[1] == 1 && c_stride[2] == c_shape[1];
    const bool C_base_is_c_contig =
        c_stride[1] == c_shape[2] && c_stride[2] == 1;

    if (!A_base_is_f_contig and !A_base_is_c_contig) {
        throw py::value_error("The 2D base of the first input array is "
                              "neither c-contiguous nor f-contiguous.");
    }
    if (!B_base_is_f_contig and !B_base_is_c_contig) {
        throw py::value_error("The 2D base of the second input array is "
                              "neither c-contiguous nor f-contiguous.");
    }
    if (!C_base_is_f_contig and !C_base_is_c_contig) {
        throw py::value_error("The 2D base of the result array is neither "
                              "c-contiguous nor f-contiguous.");
    }

    oneapi::mkl::transpose transA;
    oneapi::mkl::transpose transB;
    std::int64_t lda;
    std::int64_t ldb;

// cuBLAS supports only column-major storage
#if defined(USE_ONEMATH_CUBLAS)
    constexpr bool is_row_major = false;

    transA = A_base_is_c_contig ? oneapi::mkl::transpose::T
                                : oneapi::mkl::transpose::N;
    transB = B_base_is_c_contig ? oneapi::mkl::transpose::T
                                : oneapi::mkl::transpose::N;

    if (transA == oneapi::mkl::transpose::N) {
        lda = m;
    }
    else {
        lda = k;
    }
    if (transB == oneapi::mkl::transpose::N) {
        ldb = k;
    }
    else {
        ldb = n;
    }
#else
    bool is_row_major = true;
    if (A_base_is_f_contig && B_base_is_f_contig) {
        is_row_major = false;
    }

    if (is_row_major) {
        transA = A_base_is_f_contig ? oneapi::mkl::transpose::T
                                    : oneapi::mkl::transpose::N;
        transB = B_base_is_f_contig ? oneapi::mkl::transpose::T
                                    : oneapi::mkl::transpose::N;

        if (transA == oneapi::mkl::transpose::N) {
            lda = k;
        }
        else {
            lda = m;
        }
        if (transB == oneapi::mkl::transpose::N) {
            ldb = n;
        }
        else {
            ldb = k;
        }
    }
    else {
        transA = oneapi::mkl::transpose::N;
        transB = oneapi::mkl::transpose::N;
        lda = m;
        ldb = k;
    }
#endif // USE_ONEMATH_CUBLAS

    const std::int64_t ldc = is_row_major ? n : m;

    const int matrixA_typenum = matrixA.get_typenum();
    const int matrixB_typenum = matrixB.get_typenum();
    const int resultC_typenum = resultC.get_typenum();

    if (matrixA_typenum != matrixB_typenum) {
        throw py::value_error("matrixA and matrixB must be of the same type.");
    }

    auto array_types = dpctl_td_ns::usm_ndarray_types();
    const int matrixAB_type_id =
        array_types.typenum_to_lookup_id(matrixA_typenum);
    const int resultC_type_id =
        array_types.typenum_to_lookup_id(resultC_typenum);

    gemm_batch_impl_fn_ptr_t gemm_batch_fn =
        gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
    if (gemm_batch_fn == nullptr) {
        throw py::value_error(
            "No gemm_batch implementation is available for the specified data "
            "type of the input and output arrays.");
    }

    const char *a_typeless_ptr = matrixA.get_data();
    const char *b_typeless_ptr = matrixB.get_data();
    char *r_typeless_ptr = resultC.get_data();

    sycl::event gemm_batch_ev =
        gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
                      strideb, stridec, transA, transB, a_typeless_ptr,
                      b_typeless_ptr, r_typeless_ptr, is_row_major, depends);

    sycl::event args_ev = dpctl::utils::keep_args_alive(
        exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

    return std::make_tuple(args_ev, gemm_batch_ev, is_row_major);
}
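
// Worked example of the layout handling above: a C-contiguous A of shape
// m x k, read as a column-major buffer, is the k x m matrix A^T, so the
// cuBLAS path passes transA = T with lda = k; the default path instead
// keeps row-major storage and passes transA = N, also with lda = k. The
// default path switches to column_major only when the 2D bases of both
// inputs are f-contiguous.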

template <typename fnT, typename Tab, typename Tc>
struct GemmBatchContigFactory
{
    fnT get()
    {
        if constexpr (types::GemmBatchTypePairSupportFactory<Tab,
                                                             Tc>::is_defined) {
            return gemm_batch_impl<Tab, Tc>;
        }
        else {
            return nullptr;
        }
    }
};

void init_gemm_batch_dispatch_table(void)
{
    dpctl_td_ns::DispatchTableBuilder<gemm_batch_impl_fn_ptr_t,
                                      GemmBatchContigFactory,
                                      dpctl_td_ns::num_types>
        contig;
    contig.populate_dispatch_table(gemm_batch_dispatch_table);
}
} // namespace dpnp::extensions::blas
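
For reference, here is a minimal, self-contained sketch of the dispatch-table pattern used above, assuming a hypothetical two-type registry; the real code derives its type ids from dpctl_td_ns::usm_ndarray_types() and fills the table with dpctl_td_ns::DispatchTableBuilder over all supported types:

#include <cstdio>
#include <type_traits>

using impl_fn_ptr_t = void (*)();
constexpr int num_types = 2; // pretend lookup ids: 0 -> float, 1 -> double

template <typename Tab, typename Tc>
void typed_impl()
{
    std::printf("kernel for %zu-byte inputs and %zu-byte output\n",
                sizeof(Tab), sizeof(Tc));
}

// Stand-in for types::GemmBatchTypePairSupportFactory: declares which
// (input, output) type pairs have an implementation.
template <typename Tab, typename Tc>
struct pair_supported
    : std::bool_constant<std::is_same_v<Tab, Tc> ||
                         (std::is_same_v<Tab, float> &&
                          std::is_same_v<Tc, double>)>
{
};

// Stand-in for GemmBatchContigFactory::get(): a supported pair yields the
// typed implementation; anything else yields nullptr.
template <typename Tab, typename Tc>
impl_fn_ptr_t get_impl()
{
    if constexpr (pair_supported<Tab, Tc>::value) {
        return typed_impl<Tab, Tc>;
    }
    else {
        return nullptr;
    }
}

static impl_fn_ptr_t dispatch_table[num_types][num_types];

// Stand-in for init_gemm_batch_dispatch_table(); DispatchTableBuilder
// automates this enumeration over all type-id pairs.
void init_dispatch_table()
{
    dispatch_table[0][0] = get_impl<float, float>();
    dispatch_table[0][1] = get_impl<float, double>();
    dispatch_table[1][0] = get_impl<double, float>(); // unsupported: nullptr
    dispatch_table[1][1] = get_impl<double, double>();
}

int main()
{
    init_dispatch_table();
    // The runtime lookup mirrors gemm_batch(): ids come from array dtypes.
    if (impl_fn_ptr_t fn = dispatch_table[0][1]) {
        fn();
    }
    return 0;
}

Keeping unsupported pairs as nullptr lets the runtime lookup fail with a clear type error, as gemm_batch() does above, instead of instantiating kernels for type combinations that oneMKL does not provide.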