• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpnp / 21405943964

27 Jan 2026 04:46PM UTC coverage: 80.734% (-0.4%) from 81.104%
21405943964

Pull #2747

github

web-flow
Merge 39c9dce59 into b910c9237
Pull Request #2747: Align strides with numpy

1291 of 2452 branches covered (52.65%)

Branch coverage included in aggregate %.

30 of 30 new or added lines in 7 files covered. (100.0%)

57 existing lines in 5 files now uncovered.

19691 of 23537 relevant lines covered (83.66%)

26837.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

68.96
/dpnp/backend/extensions/blas/gemm_batch.cpp
1
//*****************************************************************************
2
// Copyright (c) 2024, Intel Corporation
3
// All rights reserved.
4
//
5
// Redistribution and use in source and binary forms, with or without
6
// modification, are permitted provided that the following conditions are met:
7
// - Redistributions of source code must retain the above copyright notice,
8
//   this list of conditions and the following disclaimer.
9
// - Redistributions in binary form must reproduce the above copyright notice,
10
//   this list of conditions and the following disclaimer in the documentation
11
//   and/or other materials provided with the distribution.
12
// - Neither the name of the copyright holder nor the names of its contributors
13
//   may be used to endorse or promote products derived from this software
14
//   without specific prior written permission.
15
//
16
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26
// THE POSSIBILITY OF SUCH DAMAGE.
27
//*****************************************************************************
28

29
#include <stdexcept>
30

31
#include <pybind11/pybind11.h>
32

33
// utils extension header
34
#include "ext/common.hpp"
35

36
// dpctl tensor headers
37
#include "utils/memory_overlap.hpp"
38
#include "utils/output_validation.hpp"
39
#include "utils/type_utils.hpp"
40

41
#include "gemm.hpp"
42
#include "types_matrix.hpp"
43

44
namespace dpnp::extensions::blas
45
{
46
namespace mkl_blas = oneapi::mkl::blas;
47
namespace py = pybind11;
48
namespace type_utils = dpctl::tensor::type_utils;
49

50
using ext::common::init_dispatch_table;
51

52
typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
53
    sycl::queue &,
54
    const std::int64_t,
55
    const std::int64_t,
56
    const std::int64_t,
57
    const std::int64_t,
58
    const std::int64_t,
59
    const std::int64_t,
60
    const std::int64_t,
61
    const std::int64_t,
62
    const std::int64_t,
63
    const std::int64_t,
64
    oneapi::mkl::transpose,
65
    oneapi::mkl::transpose,
66
    const char *,
67
    const char *,
68
    char *,
69
    const bool,
70
    const std::vector<sycl::event> &);
71

72
static gemm_batch_impl_fn_ptr_t
73
    gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
74

75
template <typename Tab, typename Tc>
76
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
77
                                   const std::int64_t m,
78
                                   const std::int64_t n,
79
                                   const std::int64_t k,
80
                                   const std::int64_t batch_size,
81
                                   const std::int64_t lda,
82
                                   const std::int64_t ldb,
83
                                   const std::int64_t ldc,
84
                                   const std::int64_t stridea,
85
                                   const std::int64_t strideb,
86
                                   const std::int64_t stridec,
87
                                   oneapi::mkl::transpose transA,
88
                                   oneapi::mkl::transpose transB,
89
                                   const char *matrixA,
90
                                   const char *matrixB,
91
                                   char *resultC,
92
                                   const bool is_row_major,
93
                                   const std::vector<sycl::event> &depends)
94
{
4,482✔
95
    type_utils::validate_type_for_device<Tab>(exec_q);
4,482✔
96
    type_utils::validate_type_for_device<Tc>(exec_q);
4,482✔
97

98
    const Tab *a = reinterpret_cast<const Tab *>(matrixA);
4,482✔
99
    const Tab *b = reinterpret_cast<const Tab *>(matrixB);
4,482✔
100
    Tc *res = reinterpret_cast<Tc *>(resultC);
4,482✔
101

102
    std::stringstream error_msg;
4,482✔
103
    bool is_exception_caught = false;
4,482✔
104

105
    sycl::event gemm_batch_event;
4,482✔
106
    try {
4,482✔
107
        auto gemm_batch_func =
4,482✔
108
            [&](sycl::queue &q, oneapi::mkl::transpose transA,
4,482✔
109
                oneapi::mkl::transpose transB, const std::int64_t m,
4,482✔
110
                const std::int64_t n, const std::int64_t k, Tab alpha,
4,482✔
111
                const Tab *a, const std::int64_t lda,
4,482✔
112
                const std::int64_t stridea, const Tab *b,
4,482✔
113
                const std::int64_t ldb, const std::int64_t strideb, Tab beta,
4,482✔
114
                Tc *c, const std::int64_t ldc, const std::int64_t stridec,
4,482✔
115
                const std::int64_t batch_size,
4,482✔
116
                const std::vector<sycl::event> &deps) -> sycl::event {
4,482✔
117
            if (is_row_major) {
4,482!
118
                return mkl_blas::row_major::gemm_batch(
4,482✔
119
                    q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
4,482✔
120
                    strideb, beta, c, ldc, stridec, batch_size, deps);
4,482✔
121
            }
4,482✔
UNCOV
122
            else {
×
UNCOV
123
                return mkl_blas::column_major::gemm_batch(
×
UNCOV
124
                    q, transA, transB, m, n, k, alpha, a, lda, stridea, b, ldb,
×
UNCOV
125
                    strideb, beta, c, ldc, stridec, batch_size, deps);
×
UNCOV
126
            }
×
127
        };
4,482✔
128
        gemm_batch_event = gemm_batch_func(
4,482✔
129
            exec_q,
4,482✔
130
            transA,     // Defines the transpose operation for matrix A:
4,482✔
131
                        // 'N' indicates no transpose, 'T' for transpose,
132
                        // or 'C' for a conjugate transpose.
133
            transB,     // Same as transA but for matrix B.
4,482✔
134
            m,          // Number of rows in matrices A and C.
4,482✔
135
            n,          // Number of columns in matrices B and C.
4,482✔
136
            k,          // Number of columns in matrix A and rows in matrix B.
4,482✔
137
            Tab(1),     // Scaling factor for the product of matrices A and B.
4,482✔
138
            a,          // Pointer to matrix A.
4,482✔
139
            lda,        // Leading dimension of matrix A, which is the
4,482✔
140
                        // stride between successive rows (for row major
141
                        // layout).
142
            stridea,    // Stride between different A matrices.
4,482✔
143
            b,          // Pointer to matrix B.
4,482✔
144
            ldb,        // Leading dimension of matrix B, similar to lda.
4,482✔
145
            strideb,    // Stride between different B matrices.
4,482✔
146
            Tab(0),     // Scaling factor for matrix C.
4,482✔
147
            res,        // Pointer to matrix C, where the result is stored.
4,482✔
148
            ldc,        // Leading dimension of matrix C.
4,482✔
149
            stridec,    // Stride between different C matrices.
4,482✔
150
            batch_size, // Specifies the number of matrix multiply
4,482✔
151
                        // operations to perform.
152
            depends);
4,482✔
153
    } catch (oneapi::mkl::exception const &e) {
4,482✔
154
        error_msg << "Unexpected MKL exception caught during gemm_batch() "
×
155
                     "call:\nreason: "
×
156
                  << e.what();
×
157
        is_exception_caught = true;
×
158
    } catch (sycl::exception const &e) {
×
159
        error_msg
×
160
            << "Unexpected SYCL exception caught during gemm_batch() call:\n"
×
161
            << e.what();
×
162
        is_exception_caught = true;
×
163
    }
×
164

165
    if (is_exception_caught) // an unexpected error occurs
4,482!
166
    {
×
167
        throw std::runtime_error(error_msg.str());
×
168
    }
×
169

170
    return gemm_batch_event;
4,482✔
171
}
4,482✔
172

173
void standardize_strides_to_nonzero(std::vector<py::ssize_t> &strides,
174
                                    const py::ssize_t *shape)
175
{
13,446✔
176
    // When shape of an array along any particular dimension is 1, the stride
177
    // along that dimension is undefined. This function standardize the strides
178
    // by calculating the non-zero value of the strides.
179
    const std::size_t ndim = strides.size();
13,446✔
180
    const bool has_zero_stride =
13,446✔
181
        std::accumulate(strides.begin(), strides.end(), 1,
13,446✔
182
                        std::multiplies<py::ssize_t>{}) == 0;
13,446✔
183

184
    if (has_zero_stride) {
13,446✔
185
        for (std::size_t i = 0; i < ndim - 1; ++i) {
3,540✔
186
            strides[i] = strides[i] == 0
2,360✔
187
                             ? std::accumulate(shape + i + 1, shape + ndim, 1,
2,360✔
188
                                               std::multiplies<py::ssize_t>{})
911✔
189
                             : strides[i];
2,360✔
190
        }
2,360✔
191
        strides[ndim - 1] = strides[ndim - 1] == 0 ? 1 : strides[ndim - 1];
1,180✔
192
    }
1,180✔
193
}
13,446✔
194

195
void standardize_strides_to_zero(std::vector<py::ssize_t> &strides,
196
                                 const py::ssize_t *shape)
197
{
13,446✔
198
    // When shape of an array along any particular dimension is 1, the stride
199
    // along that dimension is undefined. This function standardize the strides
200
    // by defining such a stride as zero. This is because for these cases,
201
    // instead of copying the array into the additional dimension for batch
202
    // multiplication, we choose to use zero as the stride between different
203
    // matrices.  Therefore, the same array is used repeatedly.
204
    const std::size_t ndim = strides.size();
13,446✔
205

206
    for (std::size_t i = 0; i < ndim; ++i) {
53,784✔
207
        if (shape[i] <= 1) {
40,338✔
208
            strides[i] = 0;
1,431✔
209
        }
1,431✔
210
    }
40,338✔
211
}
13,446✔
212

213
std::tuple<sycl::event, sycl::event, bool>
214
    gemm_batch(sycl::queue &exec_q,
215
               const dpctl::tensor::usm_ndarray &matrixA,
216
               const dpctl::tensor::usm_ndarray &matrixB,
217
               const dpctl::tensor::usm_ndarray &resultC,
218
               const std::vector<sycl::event> &depends = {})
219
{
4,482✔
220
    const int matrixA_nd = matrixA.get_ndim();
4,482✔
221
    const int matrixB_nd = matrixB.get_ndim();
4,482✔
222
    const int resultC_nd = resultC.get_ndim();
4,482✔
223

224
    if (matrixA_nd != resultC_nd || matrixB_nd != resultC_nd) {
4,482!
225
        throw py::value_error("The given arrays have incorrect dimensions.");
×
226
    }
×
227

228
    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
4,482✔
229
    if (overlap(matrixA, resultC)) {
4,482!
230
        throw py::value_error("Input array 1 and output array are overlapping "
×
231
                              "segments of memory");
×
232
    }
×
233
    if (overlap(matrixB, resultC)) {
4,482!
234
        throw py::value_error("Input array 2 and output array are overlapping "
×
235
                              "segments of memory");
×
236
    }
×
237

238
    if (!dpctl::utils::queues_are_compatible(
4,482!
239
            exec_q,
4,482✔
240
            {matrixA.get_queue(), matrixB.get_queue(), resultC.get_queue()}))
4,482✔
241
    {
×
242
        throw py::value_error(
×
243
            "USM allocations are not compatible with the execution queue.");
×
244
    }
×
245

246
    const py::ssize_t *a_shape = matrixA.get_shape_raw();
4,482✔
247
    const py::ssize_t *b_shape = matrixB.get_shape_raw();
4,482✔
248
    const py::ssize_t *c_shape = resultC.get_shape_raw();
4,482✔
249
    const std::int64_t m = a_shape[1];
4,482✔
250
    const std::int64_t n = b_shape[2];
4,482✔
251
    const std::int64_t k = a_shape[2];
4,482✔
252
    const std::int64_t batch_size = c_shape[0];
4,482✔
253
    if (a_shape[2] != b_shape[1]) {
4,482!
254
        throw py::value_error("The number of columns in A must be equal to "
×
255
                              "the number of rows in B.");
×
256
    }
×
257
    if (a_shape[1] != c_shape[1]) {
4,482!
258
        throw py::value_error("The number of rows in A must be equal to "
×
259
                              "the number of rows in result array.");
×
260
    }
×
261
    if (b_shape[2] != c_shape[2]) {
4,482!
262
        throw py::value_error("The number of columns in B must be equal to "
×
263
                              "the number of columns in result array.");
×
264
    }
×
265
    const std::int64_t src_nelems = batch_size * m * n;
4,482✔
266
    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(resultC);
4,482✔
267
    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(resultC,
4,482✔
268
                                                               src_nelems);
4,482✔
269

270
    std::vector<py::ssize_t> a_stride = matrixA.get_strides_vector();
4,482✔
271
    std::vector<py::ssize_t> b_stride = matrixB.get_strides_vector();
4,482✔
272
    std::vector<py::ssize_t> c_stride = resultC.get_strides_vector();
4,482✔
273
    standardize_strides_to_zero(a_stride, a_shape);
4,482✔
274
    standardize_strides_to_zero(b_stride, b_shape);
4,482✔
275
    standardize_strides_to_zero(c_stride, c_shape);
4,482✔
276
    const std::int64_t stridea = a_stride[0];
4,482✔
277
    const std::int64_t strideb = b_stride[0];
4,482✔
278
    const std::int64_t stridec = c_stride[0];
4,482✔
279

280
    standardize_strides_to_nonzero(a_stride, a_shape);
4,482✔
281
    standardize_strides_to_nonzero(b_stride, b_shape);
4,482✔
282
    standardize_strides_to_nonzero(c_stride, c_shape);
4,482✔
283
    const bool A_base_is_f_contig =
4,482✔
284
        a_stride[1] == 1 && a_stride[2] == a_shape[1];
4,482!
285
    const bool A_base_is_c_contig =
4,482✔
286
        a_stride[1] == a_shape[2] && a_stride[2] == 1;
4,482!
287
    const bool B_base_is_f_contig =
4,482✔
288
        b_stride[1] == 1 && b_stride[2] == b_shape[1];
4,482!
289
    const bool B_base_is_c_contig =
4,482✔
290
        b_stride[1] == b_shape[2] && b_stride[2] == 1;
4,482!
291
    const bool C_base_is_f_contig =
4,482✔
292
        c_stride[1] == 1 && c_stride[2] == c_shape[1];
4,482✔
293
    const bool C_base_is_c_contig =
4,482✔
294
        c_stride[1] == c_shape[2] && c_stride[2] == 1;
4,482!
295

296
    if (!A_base_is_f_contig and !A_base_is_c_contig) {
4,482!
297
        throw py::value_error("The 2D base of the first input array is not "
×
298
                              "c-contiguous nor f-contiguous.");
×
299
    }
×
300
    if (!B_base_is_f_contig and !B_base_is_c_contig) {
4,482!
301
        throw py::value_error("The 2D base of the second input array is not "
×
302
                              "c-contiguous nor f-contiguous.");
×
303
    }
×
304
    if (!C_base_is_f_contig and !C_base_is_c_contig) {
4,482!
305
        throw py::value_error("The 2D base of result array is not c-contiguous "
×
306
                              "nor f-contiguous.");
×
307
    }
×
308

309
    oneapi::mkl::transpose transA;
4,482✔
310
    oneapi::mkl::transpose transB;
4,482✔
311
    std::int64_t lda;
4,482✔
312
    std::int64_t ldb;
4,482✔
313

314
// cuBLAS supports only column-major storage
315
#if defined(USE_ONEMATH_CUBLAS)
316
    constexpr bool is_row_major = false;
317

318
    transA = A_base_is_c_contig ? oneapi::mkl::transpose::T
319
                                : oneapi::mkl::transpose::N;
320
    transB = B_base_is_c_contig ? oneapi::mkl::transpose::T
321
                                : oneapi::mkl::transpose::N;
322

323
    if (transA == oneapi::mkl::transpose::N) {
324
        lda = m;
325
    }
326
    else {
327
        lda = k;
328
    }
329
    if (transB == oneapi::mkl::transpose::N) {
330
        ldb = k;
331
    }
332
    else {
333
        ldb = n;
334
    }
335
#else
336
    bool is_row_major = true;
4,482✔
337
    if (A_base_is_f_contig && B_base_is_f_contig) {
4,482!
UNCOV
338
        is_row_major = false;
×
UNCOV
339
    }
×
340

341
    if (is_row_major) {
4,482!
342
        transA = A_base_is_f_contig ? oneapi::mkl::transpose::T
4,482!
343
                                    : oneapi::mkl::transpose::N;
4,482✔
344
        transB = B_base_is_f_contig ? oneapi::mkl::transpose::T
4,482!
345
                                    : oneapi::mkl::transpose::N;
4,482✔
346

347
        if (transA == oneapi::mkl::transpose::N) {
4,482!
348
            lda = k;
4,482✔
349
        }
4,482✔
UNCOV
350
        else {
×
UNCOV
351
            lda = m;
×
UNCOV
352
        }
×
353
        if (transB == oneapi::mkl::transpose::N) {
4,482!
354
            ldb = n;
4,482✔
355
        }
4,482✔
UNCOV
356
        else {
×
UNCOV
357
            ldb = k;
×
UNCOV
358
        }
×
359
    }
4,482✔
UNCOV
360
    else {
×
UNCOV
361
        transA = oneapi::mkl::transpose::N;
×
UNCOV
362
        transB = oneapi::mkl::transpose::N;
×
UNCOV
363
        lda = m;
×
UNCOV
364
        ldb = k;
×
UNCOV
365
    }
×
366
#endif // USE_ONEMATH_CUBLAS
4,482✔
367

368
    const std::int64_t ldc = is_row_major ? n : m;
4,482!
369

370
    const int matrixA_typenum = matrixA.get_typenum();
4,482✔
371
    const int matrixB_typenum = matrixB.get_typenum();
4,482✔
372
    const int resultC_typenum = resultC.get_typenum();
4,482✔
373

374
    if (matrixA_typenum != matrixB_typenum) {
4,482!
375
        throw py::value_error("matrixA and matrixB must be of the same type.");
×
376
    }
×
377

378
    auto array_types = dpctl_td_ns::usm_ndarray_types();
4,482✔
379
    const int matrixAB_type_id =
4,482✔
380
        array_types.typenum_to_lookup_id(matrixA_typenum);
4,482✔
381
    const int resultC_type_id =
4,482✔
382
        array_types.typenum_to_lookup_id(resultC_typenum);
4,482✔
383

384
    gemm_batch_impl_fn_ptr_t gemm_batch_fn =
4,482✔
385
        gemm_batch_dispatch_table[matrixAB_type_id][resultC_type_id];
4,482✔
386
    if (gemm_batch_fn == nullptr) {
4,482!
387
        throw py::value_error(
×
388
            "No gemm_batch implementation is available for the specified data "
×
389
            "type of the input and output arrays.");
×
390
    }
×
391

392
    const char *a_typeless_ptr = matrixA.get_data();
4,482✔
393
    const char *b_typeless_ptr = matrixB.get_data();
4,482✔
394
    char *r_typeless_ptr = resultC.get_data();
4,482✔
395

396
    sycl::event gemm_batch_ev =
4,482✔
397
        gemm_batch_fn(exec_q, m, n, k, batch_size, lda, ldb, ldc, stridea,
4,482✔
398
                      strideb, stridec, transA, transB, a_typeless_ptr,
4,482✔
399
                      b_typeless_ptr, r_typeless_ptr, is_row_major, depends);
4,482✔
400

401
    sycl::event args_ev = dpctl::utils::keep_args_alive(
4,482✔
402
        exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});
4,482✔
403

404
    return std::make_tuple(args_ev, gemm_batch_ev, is_row_major);
4,482✔
405
}
4,482✔
406

407
template <typename fnT, typename Tab, typename Tc>
408
struct GemmBatchContigFactory
409
{
410
    fnT get()
411
    {
392✔
412
        if constexpr (types::GemmBatchTypePairSupportFactory<Tab,
413
                                                             Tc>::is_defined) {
16✔
414
            return gemm_batch_impl<Tab, Tc>;
16✔
415
        }
416
        else {
376✔
417
            return nullptr;
376✔
418
        }
376✔
419
    }
392✔
420
};
421

422
void init_gemm_batch_dispatch_table(void)
423
{
2✔
424
    init_dispatch_table<gemm_batch_impl_fn_ptr_t, GemmBatchContigFactory>(
2✔
425
        gemm_batch_dispatch_table);
2✔
426
}
2✔
427
} // namespace dpnp::extensions::blas
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc