• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23339604691

20 Mar 2026 10:51AM UTC coverage: 64.115% (-0.005%) from 64.12%
23339604691

push

github

web-flow
Merge pull request #596 from daisytuner/mlir-opt-linear

[MLIR] Optimized linear layer in PyTorch frontend

34 of 69 new or added lines in 1 file covered. (49.28%)

3 existing lines in 2 files now uncovered.

26404 of 41182 relevant lines covered (64.12%)

399.74 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

55.85
/sdfg/src/data_flow/library_nodes/math/tensor/matmul_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/matmul_node.h"
2
#include <cstddef>
3
#include <string>
4

5
#include "sdfg/analysis/scope_analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
8
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
9
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
10
#include "sdfg/data_flow/tasklet.h"
11
#include "sdfg/element.h"
12
#include "sdfg/exceptions.h"
13
#include "sdfg/structured_control_flow/control_flow_node.h"
14
#include "sdfg/structured_control_flow/map.h"
15
#include "sdfg/structured_control_flow/sequence.h"
16
#include "sdfg/symbolic/symbolic.h"
17
#include "sdfg/types/pointer.h"
18
#include "sdfg/types/scalar.h"
19
#include "sdfg/types/tensor.h"
20
#include "sdfg/types/type.h"
21
#include "sdfg/types/utils.h"
22

23
namespace sdfg {
24
namespace math {
25
namespace tensor {
26

27
bool MatMulNode::has_basic_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides) {
    // A tensor has "basic" strides when its stride vector is exactly the
    // canonical C-order (row-major) strides derived from its shape.
    const auto expected = types::Tensor::strides_from_shape(shape);
    if (expected.size() != strides.size()) {
        return false;
    }
    for (size_t idx = 0; idx < strides.size(); ++idx) {
        if (!symbolic::eq(expected[idx], strides[idx])) {
            return false;
        }
    }
    return true;
}
39

NEW
40
bool MatMulNode::has_transposed_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides) {
    // Detects the layout produced by transposing the last two dimensions of a
    // C-order tensor: swapping both the trailing shape entries and the trailing
    // stride entries must yield a basic (row-major) layout again.
    if (shape.size() < 2) {
        return false;
    }
    const size_t rank = shape.size();
    symbolic::MultiExpression swapped_shape(shape);
    auto shape_tmp = swapped_shape[rank - 2];
    swapped_shape[rank - 2] = swapped_shape[rank - 1];
    swapped_shape[rank - 1] = shape_tmp;
    symbolic::MultiExpression swapped_strides(strides);
    auto stride_tmp = swapped_strides[strides.size() - 2];
    swapped_strides[strides.size() - 2] = swapped_strides[strides.size() - 1];
    swapped_strides[strides.size() - 1] = stride_tmp;
    return MatMulNode::has_basic_strides(swapped_shape, swapped_strides);
}
56

57
MatMulNode::MatMulNode(
58
    size_t element_id,
59
    const DebugInfo& debug_info,
60
    const graph::Vertex vertex,
61
    data_flow::DataFlowGraph& parent,
62
    const symbolic::MultiExpression& shape_a,
63
    const symbolic::MultiExpression& shape_b,
64
    const symbolic::MultiExpression& strides_a,
65
    const symbolic::MultiExpression& strides_b,
66
    symbolic::Expression offset_a,
67
    symbolic::Expression offset_b
68
)
69
    : TensorNode(
5✔
70
          element_id,
5✔
71
          debug_info,
5✔
72
          vertex,
5✔
73
          parent,
5✔
74
          LibraryNodeType_MatMul,
5✔
75
          {"Y"},
5✔
76
          {"A", "B"},
5✔
77
          data_flow::ImplementationType_NONE
5✔
78
      ),
5✔
79
      shape_a_(shape_a), shape_b_(shape_b), strides_a_(strides_a), strides_b_(strides_b), offset_a_(offset_a),
5✔
80
      offset_b_(offset_b) {
5✔
81
    if (shape_a_.size() < 2) {
5✔
82
        throw std::invalid_argument("MatMulNode: Input A must have at least 2 dimensions");
×
83
    }
×
84
    if (shape_b_.size() < 2) {
5✔
85
        throw std::invalid_argument("MatMulNode: Input B must have at least 2 dimensions");
×
86
    }
×
87
    // Compute default row-major strides if not provided
88
    if (strides_a_.empty()) {
5✔
89
        strides_a_.resize(shape_a_.size());
5✔
90
        strides_a_[shape_a_.size() - 1] = symbolic::integer(1);
5✔
91
        for (int i = static_cast<int>(shape_a_.size()) - 2; i >= 0; --i) {
11✔
92
            strides_a_[i] = symbolic::mul(strides_a_[i + 1], shape_a_[i + 1]);
6✔
93
        }
6✔
94
    }
5✔
95
    if (strides_b_.empty()) {
5✔
96
        strides_b_.resize(shape_b_.size());
5✔
97
        strides_b_[shape_b_.size() - 1] = symbolic::integer(1);
5✔
98
        for (int i = static_cast<int>(shape_b_.size()) - 2; i >= 0; --i) {
11✔
99
            strides_b_[i] = symbolic::mul(strides_b_[i + 1], shape_b_[i + 1]);
6✔
100
        }
6✔
101
    }
5✔
102
}
5✔
103

104
symbolic::Expression MatMulNode::m() const {
    // Row count of the product: the second-to-last dimension of A.
    const size_t rank = shape_a_.size();
    return shape_a_[rank - 2];
}
108

109
symbolic::Expression MatMulNode::n() const {
    // Column count of the product: the last dimension of B.
    return shape_b_.back();
}
113

114
symbolic::Expression MatMulNode::k() const {
    // Contraction dimension: the last dim of A (must match B's second-to-last).
    return shape_a_.back();
}
118

119
void MatMulNode::validate(const Function& function) const {
5✔
120
    TensorNode::validate(function);
5✔
121

122
    auto& graph = this->get_parent();
5✔
123

124
    // Check that we have exactly 2 inputs and 1 output
125
    if (graph.in_degree(*this) != 2) {
5✔
126
        throw InvalidSDFGException("MatMulNode: Expected exactly 2 inputs (A and B)");
×
127
    }
×
128
    if (graph.out_degree(*this) != 1) {
5✔
129
        throw InvalidSDFGException("MatMulNode: Expected exactly 1 output (Y)");
×
130
    }
×
131

132
    // Validate K dimension matches between A and B
133
    auto k_a = shape_a_[shape_a_.size() - 1];
5✔
134
    auto k_b = shape_b_[shape_b_.size() - 2];
5✔
135
    if (!symbolic::eq(k_a, k_b)) {
5✔
136
        throw InvalidSDFGException(
×
137
            "MatMulNode: K dimension mismatch. A has K=" + k_a->__str__() + ", B has K=" + k_b->__str__()
×
138
        );
×
139
    }
×
140
}
5✔
141

142
symbolic::SymbolSet MatMulNode::symbols() const {
    // Collects every free symbol appearing in the shapes, strides and offsets.
    symbolic::SymbolSet syms;
    auto collect = [&syms](const symbolic::Expression& expr) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    };
    for (const auto& dim : shape_a_) {
        collect(dim);
    }
    for (const auto& dim : shape_b_) {
        collect(dim);
    }
    for (const auto& stride : strides_a_) {
        collect(stride);
    }
    for (const auto& stride : strides_b_) {
        collect(stride);
    }
    collect(offset_a_);
    collect(offset_b_);
    return syms;
}
172

173
void MatMulNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Substitutes old_expression -> new_expression in every symbolic attribute.
    auto substitute = [&](symbolic::Expression& expr) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    };
    for (auto& dim : shape_a_) {
        substitute(dim);
    }
    for (auto& dim : shape_b_) {
        substitute(dim);
    }
    for (auto& stride : strides_a_) {
        substitute(stride);
    }
    for (auto& stride : strides_b_) {
        substitute(stride);
    }
    substitute(offset_a_);
    substitute(offset_b_);
}
189

190
std::unique_ptr<data_flow::DataFlowNode> MatMulNode::
191
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
192
    return std::unique_ptr<data_flow::DataFlowNode>(new MatMulNode(
×
193
        element_id, debug_info(), vertex, parent, shape_a_, shape_b_, strides_a_, strides_b_, offset_a_, offset_b_
×
194
    ));
×
195
}
×
196

197
std::string MatMulNode::toStr() const {
×
198
    std::stringstream ss;
×
199
    ss << "MatMul(";
×
200
    ss << "A=[";
×
201
    for (size_t i = 0; i < shape_a_.size(); ++i) {
×
202
        if (i > 0) ss << ", ";
×
203
        ss << shape_a_[i]->__str__();
×
204
    }
×
205
    ss << "], strides_a=[";
×
206
    for (size_t i = 0; i < strides_a_.size(); ++i) {
×
207
        if (i > 0) ss << ", ";
×
208
        ss << strides_a_[i]->__str__();
×
209
    }
×
210
    ss << "], offset_a=" << offset_a_->__str__();
×
211
    ss << ", B=[";
×
212
    for (size_t i = 0; i < shape_b_.size(); ++i) {
×
213
        if (i > 0) ss << ", ";
×
214
        ss << shape_b_[i]->__str__();
×
215
    }
×
216
    ss << "], strides_b=[";
×
217
    for (size_t i = 0; i < strides_b_.size(); ++i) {
×
218
        if (i > 0) ss << ", ";
×
219
        ss << strides_b_[i]->__str__();
×
220
    }
×
221
    ss << "], offset_b=" << offset_b_->__str__();
×
222
    ss << ")";
×
223
    return ss.str();
×
224
}
×
225

226
void free_after_copy(
227
    const std::string& copy_name, builder::StructuredSDFGBuilder& builder, structured_control_flow::Sequence& parent
228
) {
×
229
    auto& block = builder.add_block(parent, {}, DebugInfo());
×
230
    auto& access_in = builder.add_access(block, copy_name);
×
231
    auto& access_out = builder.add_access(block, copy_name);
×
232
    auto& free_node = builder.add_library_node<stdlib::FreeNode>(block, DebugInfo());
×
233
    builder.add_computational_memlet(
×
234
        block, access_in, free_node, "_ptr", {}, types::Pointer(types::Scalar(types::PrimitiveType::Void))
×
235
    );
×
236
    builder.add_computational_memlet(
×
237
        block, free_node, "_ptr", access_out, {}, types::Pointer(types::Scalar(types::PrimitiveType::Void))
×
238
    );
×
239
}
×
240

241
bool MatMulNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
242
    auto& dataflow = this->get_parent();
5✔
243
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
5✔
244

245
    if (dataflow.in_degree(*this) != 2 || dataflow.out_degree(*this) != 1) {
5✔
246
        return false;
×
247
    }
×
248

249
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
250
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
5✔
251
    int index = parent.index(block);
5✔
252
    auto& transition = parent.at(index).second;
5✔
253

254
    // Get input and output edges
255
    auto iedges = dataflow.in_edges_by_connector(*this);
5✔
256
    if (iedges.size() != 2) {
5✔
NEW
257
        return false;
×
UNCOV
258
    }
×
259
    auto* iedge_a = iedges.at(0);
5✔
260
    auto* iedge_b = iedges.at(1);
5✔
261
    auto oedges = dataflow.out_edges_by_connector(*this);
5✔
262
    if (oedges.size() != 1) {
5✔
263
        return false;
×
264
    }
×
265
    auto* oedge = oedges.at(0);
5✔
266

267
    // Check if legal - access nodes must not have other connections
268
    auto& input_node_a = static_cast<data_flow::AccessNode&>(iedge_a->src());
5✔
269
    auto& input_node_b = static_cast<data_flow::AccessNode&>(iedge_b->src());
5✔
270
    auto& output_node = static_cast<data_flow::AccessNode&>(oedge->dst());
5✔
271

272
    if (dataflow.in_degree(input_node_a) != 0 || dataflow.in_degree(input_node_b) != 0 ||
5✔
273
        dataflow.out_degree(output_node) != 0) {
5✔
274
        return false;
×
275
    }
×
276

277
    // Determine BLAS precision from primitive type
278
    auto prim_type = this->primitive_type(dataflow);
5✔
279
    blas::BLAS_Precision precision;
5✔
280
    switch (prim_type) {
5✔
281
        case types::PrimitiveType::Half:
×
282
            precision = blas::BLAS_Precision::h;
×
283
            break;
×
284
        case types::PrimitiveType::Float:
3✔
285
            precision = blas::BLAS_Precision::s;
3✔
286
            break;
3✔
287
        case types::PrimitiveType::Double:
1✔
288
            precision = blas::BLAS_Precision::d;
1✔
289
            break;
1✔
290
        default:
1✔
291
            // GEMM only supports floating point types, fall back to naive expansion
292
            return false;
1✔
293
    };
5✔
294

295
    // Add new graph after the current block
296
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
4✔
297

298
    auto copy_name_a = input_node_a.data();
4✔
299
    auto basic_strides_a = types::Tensor::strides_from_shape(shape_a_);
4✔
300
    auto copy_name_b = input_node_b.data();
4✔
301

302
    // Check if A and B have basic strides and whether they are transposed in the last dimension
303
    blas::BLAS_Transpose trans_a, trans_b;
4✔
304
    if (MatMulNode::has_basic_strides(this->shape_a(), this->strides_a())) {
4✔
305
        trans_a = blas::BLAS_Transpose::No;
4✔
306
    } else if (MatMulNode::has_transposed_strides(this->shape_a(), this->strides_a())) {
4✔
NEW
307
        trans_a = blas::BLAS_Transpose::Trans;
×
NEW
308
    } else {
×
NEW
309
        trans_a = blas::BLAS_Transpose::No;
×
NEW
310
        throw InvalidSDFGException("A must be in c-order");
×
NEW
311
    }
×
312
    if (MatMulNode::has_basic_strides(this->shape_b(), this->strides_b())) {
4✔
313
        trans_b = blas::BLAS_Transpose::No;
4✔
314
    } else if (MatMulNode::has_transposed_strides(this->shape_b(), this->strides_b())) {
4✔
NEW
315
        trans_b = blas::BLAS_Transpose::Trans;
×
NEW
316
    } else {
×
NEW
317
        trans_b = blas::BLAS_Transpose::No;
×
NEW
318
        throw InvalidSDFGException("B must be in c-order");
×
NEW
319
    }
×
320

321
    // Create maps for batch dimensions and M, N dimensions
322
    structured_control_flow::Sequence* last_scope = &new_sequence;
4✔
323
    structured_control_flow::Map* last_map = nullptr;
4✔
324
    symbolic::MultiExpression batch_vars;
4✔
325

326
    // Compute batch dimensions (all except last 2)
327
    size_t batch_dims_a = shape_a_.size() - 2;
4✔
328
    size_t batch_dims_b = shape_b_.size() - 2;
4✔
329
    size_t max_batch_dims = std::max(batch_dims_a, batch_dims_b);
4✔
330

331
    // Create maps for batch dimensions (using broadcasting)
332
    for (size_t i = 0; i < max_batch_dims; ++i) {
5✔
333
        std::string indvar_str = builder.find_new_name("_b");
1✔
334
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
1✔
335

336
        auto indvar = symbolic::symbol(indvar_str);
1✔
337
        auto init = symbolic::zero();
1✔
338
        auto update = symbolic::add(indvar, symbolic::one());
1✔
339

340
        // Determine the bound for this batch dimension (max of A and B for broadcasting)
341
        symbolic::Expression bound;
1✔
342
        size_t a_idx = batch_dims_a >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_a) : SIZE_MAX;
1✔
343
        size_t b_idx = batch_dims_b >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_b) : SIZE_MAX;
1✔
344

345
        if (a_idx != SIZE_MAX && b_idx != SIZE_MAX) {
1✔
346
            // Both have this dimension - they should be equal or one should be 1 (broadcasting)
347
            bound = shape_a_[a_idx]; // Assume they match or broadcasting is handled
1✔
348
        } else if (a_idx != SIZE_MAX) {
1✔
349
            bound = shape_a_[a_idx];
×
350
        } else {
×
351
            bound = shape_b_[b_idx];
×
352
        }
×
353

354
        auto condition = symbolic::Lt(indvar, bound);
1✔
355
        last_map = &builder.add_map(
1✔
356
            *last_scope,
1✔
357
            indvar,
1✔
358
            condition,
1✔
359
            init,
1✔
360
            update,
1✔
361
            structured_control_flow::ScheduleType_Sequential::create(),
1✔
362
            {},
1✔
363
            block.debug_info()
1✔
364
        );
1✔
365
        last_scope = &last_map->root();
1✔
366
        batch_vars.push_back(indvar);
1✔
367
    }
1✔
368

369
    auto& ref_block = builder.add_block(*last_scope, {}, block.debug_info());
4✔
370

371
    auto scalar_type = types::Scalar(prim_type);
4✔
372

373
    // Compute offsets for this batch iteration
374
    // For A: base_offset_a = offset_a + sum_i(batch_idx_i * batch_stride_a_i)
375
    symbolic::Expression a_batch_offset = offset_a_;
4✔
376
    for (size_t i = 0; i < batch_dims_a; ++i) {
5✔
377
        size_t batch_idx = max_batch_dims - batch_dims_a + i;
1✔
378
        a_batch_offset = symbolic::add(a_batch_offset, symbolic::mul(batch_vars[batch_idx], strides_a_[i]));
1✔
379
    }
1✔
380

381
    // For B: base_offset_b = offset_b + sum_i(batch_idx_i * batch_stride_b_i)
382
    symbolic::Expression b_batch_offset = offset_b_;
4✔
383
    for (size_t i = 0; i < batch_dims_b; ++i) {
5✔
384
        size_t batch_idx = max_batch_dims - batch_dims_b + i;
1✔
385
        b_batch_offset = symbolic::add(b_batch_offset, symbolic::mul(batch_vars[batch_idx], strides_b_[i]));
1✔
386
    }
1✔
387

388
    // Compute output batch offset (same as batch_vars pattern for Y)
389
    symbolic::Expression c_batch_offset = symbolic::integer(0);
4✔
390
    for (size_t i = 0; i < batch_vars.size(); ++i) {
5✔
391
        // Output has shape [batch..., M, N] with row-major strides
392
        // Stride for batch dim i is: M * N * product of remaining batch dims
393
        symbolic::Expression c_stride = symbolic::mul(this->m(), this->n());
1✔
394
        for (size_t j = i + 1; j < batch_vars.size(); ++j) {
1✔
395
            // Multiply by subsequent batch dimensions
396
            if (j < batch_dims_a) {
×
397
                c_stride = symbolic::mul(c_stride, shape_a_[j]);
×
398
            } else if (j - batch_dims_a < batch_dims_b) {
×
399
                c_stride = symbolic::mul(c_stride, shape_b_[j - batch_dims_a]);
×
400
            }
×
401
        }
×
402
        c_batch_offset = symbolic::add(c_batch_offset, symbolic::mul(batch_vars[i], c_stride));
1✔
403
    }
1✔
404

405
    // Create access nodes
406
    auto& a_access = builder.add_access(ref_block, copy_name_a, debug_info());
4✔
407
    auto& b_access = builder.add_access(ref_block, copy_name_b, debug_info());
4✔
408
    auto& c_access_in = builder.add_access(ref_block, output_node.data(), debug_info());
4✔
409

410
    std::string ref_name_a = builder.find_new_name(copy_name_a + "_ref");
4✔
411
    builder.add_container(ref_name_a, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
412
    auto& a_access_ref = builder.add_access(ref_block, ref_name_a, debug_info());
4✔
413
    std::string ref_name_b = builder.find_new_name(copy_name_b + "_ref");
4✔
414
    builder.add_container(ref_name_b, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
415
    auto& b_access_ref = builder.add_access(ref_block, ref_name_b, debug_info());
4✔
416
    std::string ref_name_c = builder.find_new_name(output_node.data() + "_ref");
4✔
417
    builder.add_container(ref_name_c, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
418
    auto& c_access_ref_in = builder.add_access(ref_block, ref_name_c, debug_info());
4✔
419

420
    builder.add_reference_memlet(
4✔
421
        ref_block, a_access, a_access_ref, {a_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
422
    );
4✔
423
    builder.add_reference_memlet(
4✔
424
        ref_block, b_access, b_access_ref, {b_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
425
    );
4✔
426
    builder.add_reference_memlet(
4✔
427
        ref_block, c_access_in, c_access_ref_in, {c_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
428
    );
4✔
429

430
    // Create block with GEMM library node
431
    auto& gemm_block = builder.add_block(*last_scope, {}, block.debug_info());
4✔
432

433
    // Leading dimensions: stride of the row dimension (second-to-last dim)
434
    symbolic::Expression lda, ldb;
4✔
435
    if (trans_a == blas::BLAS_Transpose::No) {
4✔
436
        // For row-major A [m * k] -> lda = k
437
        lda = strides_a_[strides_a_.size() - 2];
4✔
438
    } else {
4✔
439
        // For row-major A [m * k] -> lda = m
NEW
440
        ldb = strides_a_[strides_a_.size() - 1];
×
NEW
441
    }
×
442
    if (trans_b == blas::BLAS_Transpose::No) {
4✔
443
        // For row-major B [k * n] -> ldb = n
444
        ldb = strides_b_[strides_b_.size() - 2];
4✔
445
    } else {
4✔
446
        // For row-major B [k * n] -> ldb = k
NEW
447
        ldb = strides_b_[strides_b_.size() - 1];
×
NEW
448
    }
×
449
    // For row-major C [m * n] -> ldc = n
450
    auto ldc = this->n();
4✔
451

452
    // Add GEMM node: C = alpha * A * B + beta * C
453
    // With alpha = 1.0, beta = 0.0: C = A * B
454
    auto& gemm_node = builder.add_library_node<blas::GEMMNode>(
4✔
455
        gemm_block,
4✔
456
        debug_info(),
4✔
457
        blas::ImplementationType_BLAS,
4✔
458
        precision,
4✔
459
        blas::BLAS_Layout::RowMajor,
4✔
460
        trans_a,
4✔
461
        trans_b,
4✔
462
        this->m(),
4✔
463
        this->n(),
4✔
464
        this->k(),
4✔
465
        lda,
4✔
466
        ldb,
4✔
467
        ldc
4✔
468
    );
4✔
469

470
    auto& a_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_a, debug_info());
4✔
471
    auto& b_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_b, debug_info());
4✔
472
    auto& c_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_c, debug_info());
4✔
473

474
    auto& c_access_ref_out = builder.add_access(gemm_block, ref_name_c, debug_info());
4✔
475

476
    // Create alpha and beta constants
477
    auto& alpha_const = builder.add_constant(gemm_block, "1.0", scalar_type, debug_info());
4✔
478
    auto& beta_const = builder.add_constant(gemm_block, "0.0", scalar_type, debug_info());
4✔
479

480
    // Connect memlets with batch offsets
481
    // Input A with offset
482
    builder.add_computational_memlet(
4✔
483
        gemm_block, a_access_ref_in_gemm, gemm_node, "__A", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
484
    );
4✔
485
    // Input B with offset
486
    builder.add_computational_memlet(
4✔
487
        gemm_block, b_access_ref_in_gemm, gemm_node, "__B", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
488
    );
4✔
489
    // Input C (for beta * C, but beta=0 so just needs to be connected)
490
    builder.add_computational_memlet(
4✔
491
        gemm_block, c_access_ref_in_gemm, gemm_node, "__C", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
492
    );
4✔
493
    // Alpha constant
494
    builder.add_computational_memlet(gemm_block, alpha_const, gemm_node, "__alpha", {}, scalar_type, debug_info());
4✔
495
    // Beta constant
496
    builder.add_computational_memlet(gemm_block, beta_const, gemm_node, "__beta", {}, scalar_type, debug_info());
4✔
497
    // Output C
498
    builder.add_computational_memlet(
4✔
499
        gemm_block, gemm_node, "__C", c_access_ref_out, {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
500
    );
4✔
501

502
    // Free copies if we made them
503
    if (copy_name_a != input_node_a.data()) {
4✔
504
        free_after_copy(copy_name_a, builder, new_sequence);
×
505
    }
×
506
    if (copy_name_b != input_node_b.data()) {
4✔
507
        free_after_copy(copy_name_b, builder, new_sequence);
×
508
    }
×
509

510
    // Remove the original nodes
511
    builder.remove_memlet(block, *iedge_a);
4✔
512
    builder.remove_memlet(block, *iedge_b);
4✔
513
    builder.remove_memlet(block, *oedge);
4✔
514
    if (&input_node_a != &input_node_b) {
4✔
515
        builder.remove_node(block, input_node_a);
4✔
516
    }
4✔
517
    builder.remove_node(block, input_node_b);
4✔
518
    builder.remove_node(block, output_node);
4✔
519
    builder.remove_node(block, *this);
4✔
520
    builder.remove_child(parent, index + 1);
4✔
521

522
    return true;
4✔
523
}
4✔
524

525
nlohmann::json MatMulNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    // Serializes the node's code tag plus all shape/stride/offset metadata,
    // rendering each symbolic expression as a string.
    const MatMulNode& matmul_node = static_cast<const MatMulNode&>(library_node);
    nlohmann::json j;

    j["code"] = matmul_node.code().value();

    serializer::JSONSerializer serializer;

    // Render a list of symbolic expressions as a JSON array of strings.
    auto expressions_to_json = [&serializer](const symbolic::MultiExpression& exprs) {
        nlohmann::json arr = nlohmann::json::array();
        for (auto& expr : exprs) {
            arr.push_back(serializer.expression(expr));
        }
        return arr;
    };

    j["shape_a"] = expressions_to_json(matmul_node.shape_a());
    j["shape_b"] = expressions_to_json(matmul_node.shape_b());
    j["strides_a"] = expressions_to_json(matmul_node.strides_a());
    j["strides_b"] = expressions_to_json(matmul_node.strides_b());
    j["offset_a"] = serializer.expression(matmul_node.offset_a());
    j["offset_b"] = serializer.expression(matmul_node.offset_b());

    return j;
}
558

559
data_flow::LibraryNode& MatMulNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Rebuilds a MatMulNode from its JSON form; shapes are mandatory, while
    // strides and offsets are optional.
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("shape_a"));
    assert(j.contains("shape_b"));

    // Parse a JSON array of expression strings into symbolic expressions.
    auto parse_list = [](const nlohmann::json& arr) {
        symbolic::MultiExpression parsed;
        for (const auto& entry : arr) {
            parsed.push_back(symbolic::parse(entry.get<std::string>()));
        }
        return parsed;
    };

    symbolic::MultiExpression shape_a = parse_list(j["shape_a"]);
    symbolic::MultiExpression shape_b = parse_list(j["shape_b"]);

    // Missing strides stay empty; the MatMulNode constructor derives row-major
    // defaults from the shapes in that case.
    symbolic::MultiExpression strides_a;
    if (j.contains("strides_a")) {
        strides_a = parse_list(j["strides_a"]);
    }
    symbolic::MultiExpression strides_b;
    if (j.contains("strides_b")) {
        strides_b = parse_list(j["strides_b"]);
    }

    // Missing offsets default to zero.
    symbolic::Expression offset_a = symbolic::integer(0);
    if (j.contains("offset_a")) {
        offset_a = symbolic::parse(j["offset_a"].get<std::string>());
    }
    symbolic::Expression offset_b = symbolic::integer(0);
    if (j.contains("offset_b")) {
        offset_b = symbolic::parse(j["offset_b"].get<std::string>());
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder
        .add_library_node<MatMulNode>(parent, debug_info, shape_a, shape_b, strides_a, strides_b, offset_a, offset_b);
}
608

609
} // namespace tensor
610
} // namespace math
611
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc