
daisytuner / docc / build 22963489578

11 Mar 2026 04:36PM UTC, coverage: 63.494% (-0.001%) from 63.495%

Pull Request #579 (github / web-flow): Merge 986a5904a into 1cb8d452f
"Skip nested GPU tiling on the same container to avoid redefinitions o…"

3 of 5 new or added lines in 1 file covered (60.0%).
101 existing lines in 2 files now uncovered.
24712 of 38920 relevant lines covered (63.49%).
369.37 hits per line.

Source File: /sdfg/src/data_flow/library_nodes/math/tensor/matmul_node.cpp (67.61% of lines covered)

#include "sdfg/data_flow/library_nodes/math/tensor/matmul_node.h"

#include <algorithm> // std::max (used in expand)
#include <cassert>   // assert (used in deserialize)
#include <cstdint>   // SIZE_MAX (used in expand)
#include <sstream>   // std::stringstream (used in toStr)
#include <string>

#include "sdfg/analysis/scope_analysis.h"
#include "sdfg/builder/structured_sdfg_builder.h"
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
#include "sdfg/data_flow/tasklet.h"
#include "sdfg/element.h"
#include "sdfg/structured_control_flow/control_flow_node.h"
#include "sdfg/structured_control_flow/map.h"
#include "sdfg/structured_control_flow/sequence.h"
#include "sdfg/symbolic/symbolic.h"
#include "sdfg/types/pointer.h"
#include "sdfg/types/scalar.h"
#include "sdfg/types/tensor.h"
#include "sdfg/types/type.h"
#include "sdfg/types/utils.h"

namespace sdfg {
namespace math {
namespace tensor {

MatMulNode::MatMulNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const symbolic::MultiExpression& shape_a,
    const symbolic::MultiExpression& shape_b,
    const symbolic::MultiExpression& strides_a,
    const symbolic::MultiExpression& strides_b,
    symbolic::Expression offset_a,
    symbolic::Expression offset_b
)
    : TensorNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_MatMul,
          {"Y"},
          {"A", "B"},
          data_flow::ImplementationType_NONE
      ),
      shape_a_(shape_a), shape_b_(shape_b), strides_a_(strides_a), strides_b_(strides_b), offset_a_(offset_a),
      offset_b_(offset_b) {
    if (shape_a_.size() < 2) {
        throw std::invalid_argument("MatMulNode: Input A must have at least 2 dimensions");
    }
    if (shape_b_.size() < 2) {
        throw std::invalid_argument("MatMulNode: Input B must have at least 2 dimensions");
    }
    // Compute default row-major strides if not provided
    if (strides_a_.empty()) {
        strides_a_.resize(shape_a_.size());
        strides_a_[shape_a_.size() - 1] = symbolic::integer(1);
        for (int i = static_cast<int>(shape_a_.size()) - 2; i >= 0; --i) {
            strides_a_[i] = symbolic::mul(strides_a_[i + 1], shape_a_[i + 1]);
        }
    }
    if (strides_b_.empty()) {
        strides_b_.resize(shape_b_.size());
        strides_b_[shape_b_.size() - 1] = symbolic::integer(1);
        for (int i = static_cast<int>(shape_b_.size()) - 2; i >= 0; --i) {
            strides_b_[i] = symbolic::mul(strides_b_[i + 1], shape_b_[i + 1]);
        }
    }
}
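
// Illustrative worked example (not part of the original source): for an input
// A with shape [2, 3, 4] and no strides supplied, the loop above produces the
// row-major strides
//     strides_a_[2] = 1
//     strides_a_[1] = strides_a_[2] * shape_a_[2] = 1 * 4 = 4
//     strides_a_[0] = strides_a_[1] * shape_a_[1] = 4 * 3 = 12
// so element (b, i, j) lives at linear offset b*12 + i*4 + j.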

symbolic::Expression MatMulNode::m() const {
    // M is the second-to-last dimension of A
    return shape_a_[shape_a_.size() - 2];
}

symbolic::Expression MatMulNode::n() const {
    // N is the last dimension of B
    return shape_b_[shape_b_.size() - 1];
}

symbolic::Expression MatMulNode::k() const {
    // K is the last dimension of A (and second-to-last of B)
    return shape_a_[shape_a_.size() - 1];
}
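
// Illustrative worked example (not part of the original source): for a batched
// product A[2, 3, 4] x B[2, 4, 5], the accessors above yield M = 3, N = 5 and
// K = 4, and the output Y has shape [2, 3, 5].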

void MatMulNode::validate(const Function& function) const {
    TensorNode::validate(function);

    auto& graph = this->get_parent();

    // Check that we have exactly 2 inputs and 1 output
    if (graph.in_degree(*this) != 2) {
        throw InvalidSDFGException("MatMulNode: Expected exactly 2 inputs (A and B)");
    }
    if (graph.out_degree(*this) != 1) {
        throw InvalidSDFGException("MatMulNode: Expected exactly 1 output (Y)");
    }

    // Validate K dimension matches between A and B
    auto k_a = shape_a_[shape_a_.size() - 1];
    auto k_b = shape_b_[shape_b_.size() - 2];
    if (!symbolic::eq(k_a, k_b)) {
        throw InvalidSDFGException(
            "MatMulNode: K dimension mismatch. A has K=" + k_a->__str__() + ", B has K=" + k_b->__str__()
        );
    }
}
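
// Illustrative example (not part of the original source): validating
// A[3, 4] against B[5, 6] throws, because K from A is 4 while the
// second-to-last dimension of B is 5.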

symbolic::SymbolSet MatMulNode::symbols() const {
    symbolic::SymbolSet syms;
    for (const auto& dim : shape_a_) {
        for (auto& atom : symbolic::atoms(dim)) {
            syms.insert(atom);
        }
    }
    for (const auto& dim : shape_b_) {
        for (auto& atom : symbolic::atoms(dim)) {
            syms.insert(atom);
        }
    }
    for (const auto& stride : strides_a_) {
        for (auto& atom : symbolic::atoms(stride)) {
            syms.insert(atom);
        }
    }
    for (const auto& stride : strides_b_) {
        for (auto& atom : symbolic::atoms(stride)) {
            syms.insert(atom);
        }
    }
    for (auto& atom : symbolic::atoms(offset_a_)) {
        syms.insert(atom);
    }
    for (auto& atom : symbolic::atoms(offset_b_)) {
        syms.insert(atom);
    }
    return syms;
}

void MatMulNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    for (auto& dim : shape_a_) {
        dim = symbolic::subs(dim, old_expression, new_expression);
    }
    for (auto& dim : shape_b_) {
        dim = symbolic::subs(dim, old_expression, new_expression);
    }
    for (auto& stride : strides_a_) {
        stride = symbolic::subs(stride, old_expression, new_expression);
    }
    for (auto& stride : strides_b_) {
        stride = symbolic::subs(stride, old_expression, new_expression);
    }
    offset_a_ = symbolic::subs(offset_a_, old_expression, new_expression);
    offset_b_ = symbolic::subs(offset_b_, old_expression, new_expression);
}

std::unique_ptr<data_flow::DataFlowNode> MatMulNode::
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
    return std::unique_ptr<data_flow::DataFlowNode>(new MatMulNode(
        element_id, debug_info(), vertex, parent, shape_a_, shape_b_, strides_a_, strides_b_, offset_a_, offset_b_
    ));
}

std::string MatMulNode::toStr() const {
    std::stringstream ss;
    ss << "MatMul(";
    ss << "A=[";
    for (size_t i = 0; i < shape_a_.size(); ++i) {
        if (i > 0) ss << ", ";
        ss << shape_a_[i]->__str__();
    }
    ss << "], strides_a=[";
    for (size_t i = 0; i < strides_a_.size(); ++i) {
        if (i > 0) ss << ", ";
        ss << strides_a_[i]->__str__();
    }
    ss << "], offset_a=" << offset_a_->__str__();
    ss << ", B=[";
    for (size_t i = 0; i < shape_b_.size(); ++i) {
        if (i > 0) ss << ", ";
        ss << shape_b_[i]->__str__();
    }
    ss << "], strides_b=[";
    for (size_t i = 0; i < strides_b_.size(); ++i) {
        if (i > 0) ss << ", ";
        ss << strides_b_[i]->__str__();
    }
    ss << "], offset_b=" << offset_b_->__str__();
    ss << ")";
    return ss.str();
}

std::string copy_if_view(
    const std::string& name,
    builder::StructuredSDFGBuilder& builder,
    structured_control_flow::Sequence& parent,
    types::PrimitiveType type,
    const symbolic::MultiExpression& shape,
    const symbolic::MultiExpression& strides,
    symbolic::Expression offset
) {
    // If the tensor is already a view (has non-default strides or offset), we need to create a copy to ensure correct
    // semantics
    types::Tensor tensor_type(type, shape, strides, offset);

    auto C_style_strides = tensor_type.strides_from_shape(shape);

    bool is_view = false;
    for (size_t i = 0; i < strides.size(); ++i) {
        if (!symbolic::eq(strides[i], C_style_strides[i])) {
            is_view = true;
            break;
        }
    }

    if (is_view) {
        std::string copy_name = builder.find_new_name(name + "_copy");
        types::Pointer copy_type((types::Scalar(types::PrimitiveType::Void)));
        builder.add_container(copy_name, copy_type);
        symbolic::Expression num_elements = symbolic::one();
        for (const auto& dim : shape) {
            num_elements = symbolic::mul(num_elements, dim);
        }
        auto elem_size = types::get_type_size(types::Scalar(type));
        auto copy_size = symbolic::mul(num_elements, elem_size);

        // Allocate a C-order copy
        auto& alloc_block = builder.add_block(parent, {}, DebugInfo());
        auto& out_access = builder.add_access(alloc_block, copy_name);
        auto& malloc_node = builder.add_library_node<stdlib::MallocNode>(alloc_block, DebugInfo(), copy_size);
        builder.add_computational_memlet(
            alloc_block, malloc_node, "_ret", out_access, {}, types::Pointer(types::Scalar(type))
        );

        // Build a loop nest over each dimension
        structured_control_flow::Sequence* inner_scope = &parent;
        std::vector<symbolic::Expression> loop_vars;
        std::vector<symbolic::Expression> orig_accesses;
        for (size_t i = 0; i < shape.size(); ++i) {
            std::string indvar_str = builder.find_new_name(name + "_ci");
            builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
            auto indvar = symbolic::symbol(indvar_str);
            auto init = symbolic::zero();
            auto update = symbolic::add(indvar, symbolic::one());
            auto condition = symbolic::Lt(indvar, shape[i]);
            auto& copy_map =
                builder.add_map(*inner_scope, indvar, condition, init, update, ScheduleType_Sequential::create());
            inner_scope = &copy_map.root();
            loop_vars.push_back(indvar);
        }

        // Inside the innermost loop: copy one element
        auto& copy_block = builder.add_block(*inner_scope);
        auto& in_access_copy = builder.add_access(copy_block, name);
        auto& out_access_copy = builder.add_access(copy_block, copy_name);
        auto& tasklet = builder.add_tasklet(copy_block, data_flow::TaskletCode::assign, "_out", {"_in"});

        // Read with original strides/offset
        builder.add_computational_memlet(copy_block, in_access_copy, tasklet, "_in", loop_vars, tensor_type);
        // Write with C-order strides (default strides, zero offset)
        types::Tensor c_order_type(type, shape);
        builder.add_computational_memlet(copy_block, tasklet, "_out", out_access_copy, loop_vars, c_order_type);

        return copy_name;
    } else {
        return name;
    }
}
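
// Illustrative example (not part of the original source): a transposed 2-D
// view with shape [M, K] and strides [1, M] differs from the C-order strides
// [K, 1], so is_view becomes true and the loop nest above materializes a
// packed row-major copy; a plain dense tensor is returned unchanged.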

void free_after_copy(
    const std::string& copy_name, builder::StructuredSDFGBuilder& builder, structured_control_flow::Sequence& parent
) {
    auto& block = builder.add_block(parent, {}, DebugInfo());
    auto& access_in = builder.add_access(block, copy_name);
    auto& access_out = builder.add_access(block, copy_name);
    auto& free_node = builder.add_library_node<stdlib::FreeNode>(block, DebugInfo());
    builder.add_computational_memlet(
        block, access_in, free_node, "_ptr", {}, types::Pointer(types::Scalar(types::PrimitiveType::Void))
    );
    builder.add_computational_memlet(
        block, free_node, "_ptr", access_out, {}, types::Pointer(types::Scalar(types::PrimitiveType::Void))
    );
}

bool MatMulNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    auto& dataflow = this->get_parent();
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());

    if (dataflow.in_degree(*this) != 2 || dataflow.out_degree(*this) != 1) {
        return false;
    }

    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
    int index = parent.index(block);
    auto& transition = parent.at(index).second;

    // Get input and output edges
    data_flow::Memlet* iedge_a = nullptr;
    data_flow::Memlet* iedge_b = nullptr;
    for (auto& iedge : dataflow.in_edges(*this)) {
        if (iedge.dst_conn() == "A") {
            iedge_a = &iedge;
        } else if (iedge.dst_conn() == "B") {
            iedge_b = &iedge;
        }
    }
    auto& oedge = *dataflow.out_edges(*this).begin();

    if (!iedge_a || !iedge_b) {
        return false;
    }

    // Check if legal - access nodes must not have other connections
    auto& input_node_a = static_cast<data_flow::AccessNode&>(iedge_a->src());
    auto& input_node_b = static_cast<data_flow::AccessNode&>(iedge_b->src());
    auto& output_node = static_cast<data_flow::AccessNode&>(oedge.dst());

    if (dataflow.in_degree(input_node_a) != 0 || dataflow.in_degree(input_node_b) != 0 ||
        dataflow.out_degree(output_node) != 0) {
        return false;
    }

    // Determine BLAS precision from primitive type
    auto prim_type = this->primitive_type(dataflow);
    blas::BLAS_Precision precision;
    switch (prim_type) {
        case types::PrimitiveType::Half:
            precision = blas::BLAS_Precision::h;
            break;
        case types::PrimitiveType::Float:
            precision = blas::BLAS_Precision::s;
            break;
        case types::PrimitiveType::Double:
            precision = blas::BLAS_Precision::d;
            break;
        default:
            // GEMM only supports floating-point types; fall back to naive expansion
            return false;
    }

    // Insert a new sequence before the current block; the original block is removed at the end
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());

    auto copy_name_a =
        copy_if_view(input_node_a.data(), builder, new_sequence, prim_type, shape_a_, strides_a_, offset_a_);
    strides_a_ = types::Tensor::strides_from_shape(shape_a_);
    auto copy_name_b =
        copy_if_view(input_node_b.data(), builder, new_sequence, prim_type, shape_b_, strides_b_, offset_b_);
    strides_b_ = types::Tensor::strides_from_shape(shape_b_);

    // Create maps for batch dimensions and M, N dimensions
    structured_control_flow::Sequence* last_scope = &new_sequence;
    structured_control_flow::Map* last_map = nullptr;
    symbolic::MultiExpression batch_vars;

    // Compute batch dimensions (all except last 2)
    size_t batch_dims_a = shape_a_.size() - 2;
    size_t batch_dims_b = shape_b_.size() - 2;
    size_t max_batch_dims = std::max(batch_dims_a, batch_dims_b);

    // Create maps for batch dimensions (using broadcasting)
    for (size_t i = 0; i < max_batch_dims; ++i) {
        std::string indvar_str = builder.find_new_name("_b");
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));

        auto indvar = symbolic::symbol(indvar_str);
        auto init = symbolic::zero();
        auto update = symbolic::add(indvar, symbolic::one());

        // Determine the bound for this batch dimension (max of A and B for broadcasting)
        symbolic::Expression bound;
        size_t a_idx = batch_dims_a >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_a) : SIZE_MAX;
        size_t b_idx = batch_dims_b >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_b) : SIZE_MAX;

        if (a_idx != SIZE_MAX && b_idx != SIZE_MAX) {
            // Both have this dimension - they should be equal or one should be 1 (broadcasting)
            bound = shape_a_[a_idx]; // Assume they match or broadcasting is handled
        } else if (a_idx != SIZE_MAX) {
            bound = shape_a_[a_idx];
        } else {
            bound = shape_b_[b_idx];
        }
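
        // Illustrative example (not part of the original source): with
        // max_batch_dims = 2 and batch_dims_a = 1 (A is 3-D, B is 4-D),
        // i = 0 yields a_idx = SIZE_MAX (A lacks this leading batch dimension)
        // and i = 1 yields a_idx = 1 - (2 - 1) = 0: batch dimensions are
        // aligned from the right, NumPy-style.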

        auto condition = symbolic::Lt(indvar, bound);
        last_map = &builder.add_map(
            *last_scope,
            indvar,
            condition,
            init,
            update,
            structured_control_flow::ScheduleType_Sequential::create(),
            {},
            block.debug_info()
        );
        last_scope = &last_map->root();
        batch_vars.push_back(indvar);
    }

    auto& ref_block = builder.add_block(*last_scope, {}, block.debug_info());

    auto scalar_type = types::Scalar(prim_type);

    // Compute offsets for this batch iteration
    // For A: base_offset_a = offset_a + sum_i(batch_idx_i * batch_stride_a_i)
    symbolic::Expression a_batch_offset = offset_a_;
    for (size_t i = 0; i < batch_dims_a; ++i) {
        size_t batch_idx = max_batch_dims - batch_dims_a + i;
        a_batch_offset = symbolic::add(a_batch_offset, symbolic::mul(batch_vars[batch_idx], strides_a_[i]));
    }

    // For B: base_offset_b = offset_b + sum_i(batch_idx_i * batch_stride_b_i)
    symbolic::Expression b_batch_offset = offset_b_;
    for (size_t i = 0; i < batch_dims_b; ++i) {
        size_t batch_idx = max_batch_dims - batch_dims_b + i;
        b_batch_offset = symbolic::add(b_batch_offset, symbolic::mul(batch_vars[batch_idx], strides_b_[i]));
    }

    // Compute output batch offset (same as batch_vars pattern for Y)
    symbolic::Expression c_batch_offset = symbolic::integer(0);
    for (size_t i = 0; i < batch_vars.size(); ++i) {
        // Output has shape [batch..., M, N] with row-major strides
        // Stride for batch dim i is: M * N * product of remaining batch dims
        symbolic::Expression c_stride = symbolic::mul(this->m(), this->n());
        for (size_t j = i + 1; j < batch_vars.size(); ++j) {
            // Multiply by subsequent batch dimensions
            if (j < batch_dims_a) {
                c_stride = symbolic::mul(c_stride, shape_a_[j]);
            } else if (j - batch_dims_a < batch_dims_b) {
                c_stride = symbolic::mul(c_stride, shape_b_[j - batch_dims_a]);
            }
        }
        c_batch_offset = symbolic::add(c_batch_offset, symbolic::mul(batch_vars[i], c_stride));
    }

    // Create access nodes
    auto& a_access = builder.add_access(ref_block, copy_name_a, debug_info());
    auto& b_access = builder.add_access(ref_block, copy_name_b, debug_info());
    auto& c_access_in = builder.add_access(ref_block, output_node.data(), debug_info());

    std::string ref_name_a = builder.find_new_name(copy_name_a + "_ref");
    builder.add_container(ref_name_a, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
    auto& a_access_ref = builder.add_access(ref_block, ref_name_a, debug_info());
    std::string ref_name_b = builder.find_new_name(copy_name_b + "_ref");
    builder.add_container(ref_name_b, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
    auto& b_access_ref = builder.add_access(ref_block, ref_name_b, debug_info());
    std::string ref_name_c = builder.find_new_name(output_node.data() + "_ref");
    builder.add_container(ref_name_c, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
    auto& c_access_ref_in = builder.add_access(ref_block, ref_name_c, debug_info());

    builder.add_reference_memlet(
        ref_block, a_access, a_access_ref, {a_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
    );
    builder.add_reference_memlet(
        ref_block, b_access, b_access_ref, {b_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
    );
    builder.add_reference_memlet(
        ref_block, c_access_in, c_access_ref_in, {c_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
    );

    // Create block with GEMM library node
    auto& gemm_block = builder.add_block(*last_scope, {}, block.debug_info());

    // Leading dimensions: stride of the row dimension (second-to-last dim)
    // For row-major A[M, K]: lda = stride for M dimension = strides_a_[-2]
    // For row-major B[K, N]: ldb = stride for K dimension = strides_b_[-2]
    auto lda = strides_a_[strides_a_.size() - 2];
    auto ldb = strides_b_[strides_b_.size() - 2];
    // For output C[M, N] in row-major: ldc = N
    auto ldc = this->n();
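
    // Illustrative example (not part of the original source): since the
    // strides were reset to C order above, a dense A[M, K] has strides [K, 1]
    // and lda = K; likewise B[K, N] gives ldb = N, matching ldc = N for the
    // row-major output C[M, N].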

    // Add GEMM node: C = alpha * A * B + beta * C
    // With alpha = 1.0, beta = 0.0: C = A * B
    auto& gemm_node = builder.add_library_node<blas::GEMMNode>(
        gemm_block,
        debug_info(),
        blas::ImplementationType_BLAS,
        precision,
        blas::BLAS_Layout::RowMajor,
        blas::BLAS_Transpose::No, // trans_a
        blas::BLAS_Transpose::No, // trans_b
        this->m(),
        this->n(),
        this->k(),
        lda,
        ldb,
        ldc
    );

    auto& a_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_a, debug_info());
    auto& b_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_b, debug_info());
    auto& c_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_c, debug_info());

    auto& c_access_ref_out = builder.add_access(gemm_block, ref_name_c, debug_info());

    // Create alpha and beta constants
    auto& alpha_const = builder.add_constant(gemm_block, "1.0", scalar_type, debug_info());
    auto& beta_const = builder.add_constant(gemm_block, "0.0", scalar_type, debug_info());

    // Connect memlets with batch offsets
    // Input A with offset
    builder.add_computational_memlet(
        gemm_block, a_access_ref_in_gemm, gemm_node, "__A", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
    );
    // Input B with offset
    builder.add_computational_memlet(
        gemm_block, b_access_ref_in_gemm, gemm_node, "__B", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
    );
    // Input C (for beta * C, but beta=0 so just needs to be connected)
    builder.add_computational_memlet(
        gemm_block, c_access_ref_in_gemm, gemm_node, "__C", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
    );
    // Alpha constant
    builder.add_computational_memlet(gemm_block, alpha_const, gemm_node, "__alpha", {}, scalar_type, debug_info());
    // Beta constant
    builder.add_computational_memlet(gemm_block, beta_const, gemm_node, "__beta", {}, scalar_type, debug_info());
    // Output C
    builder.add_computational_memlet(
        gemm_block, gemm_node, "__C", c_access_ref_out, {}, ::sdfg::types::Pointer(scalar_type), debug_info()
    );

    // Free copies if we made them
    if (copy_name_a != input_node_a.data()) {
        free_after_copy(copy_name_a, builder, new_sequence);
    }
    if (copy_name_b != input_node_b.data()) {
        free_after_copy(copy_name_b, builder, new_sequence);
    }

    // Remove the original nodes
    builder.remove_memlet(block, *iedge_a);
    builder.remove_memlet(block, *iedge_b);
    builder.remove_memlet(block, oedge);
    if (&input_node_a != &input_node_b) {
        builder.remove_node(block, input_node_a);
    }
    builder.remove_node(block, input_node_b);
    builder.remove_node(block, output_node);
    builder.remove_node(block, *this);
    builder.remove_child(parent, index + 1);

    return true;
}
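
// Illustrative result (not part of the original source): for a plain 2-D
// A[M, K] x B[K, N] with default strides, no copies and no batch maps are
// created; the new sequence reduces to a reference block feeding a single
// row-major GEMM that computes Y = 1.0 * A * B + 0.0 * Y.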

nlohmann::json MatMulNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const MatMulNode& matmul_node = static_cast<const MatMulNode&>(library_node);
    nlohmann::json j;

    j["code"] = matmul_node.code().value();

    serializer::JSONSerializer serializer;

    j["shape_a"] = nlohmann::json::array();
    for (auto& dim : matmul_node.shape_a()) {
        j["shape_a"].push_back(serializer.expression(dim));
    }

    j["shape_b"] = nlohmann::json::array();
    for (auto& dim : matmul_node.shape_b()) {
        j["shape_b"].push_back(serializer.expression(dim));
    }

    j["strides_a"] = nlohmann::json::array();
    for (auto& stride : matmul_node.strides_a()) {
        j["strides_a"].push_back(serializer.expression(stride));
    }

    j["strides_b"] = nlohmann::json::array();
    for (auto& stride : matmul_node.strides_b()) {
        j["strides_b"].push_back(serializer.expression(stride));
    }

    j["offset_a"] = serializer.expression(matmul_node.offset_a());
    j["offset_b"] = serializer.expression(matmul_node.offset_b());

    return j;
}
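
// Illustrative shape of the serialized payload (not part of the original
// source; the exact expression strings depend on the symbolic printer):
//   {
//     "code": <matmul_node.code().value()>,
//     "shape_a": ["M", "K"], "shape_b": ["K", "N"],
//     "strides_a": ["K", "1"], "strides_b": ["N", "1"],
//     "offset_a": "0", "offset_b": "0"
//   }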

data_flow::LibraryNode& MatMulNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("shape_a"));
    assert(j.contains("shape_b"));

    symbolic::MultiExpression shape_a;
    for (const auto& dim : j["shape_a"]) {
        shape_a.push_back(symbolic::parse(dim.get<std::string>()));
    }

    symbolic::MultiExpression shape_b;
    for (const auto& dim : j["shape_b"]) {
        shape_b.push_back(symbolic::parse(dim.get<std::string>()));
    }

    symbolic::MultiExpression strides_a;
    if (j.contains("strides_a")) {
        for (const auto& stride : j["strides_a"]) {
            strides_a.push_back(symbolic::parse(stride.get<std::string>()));
        }
    }

    symbolic::MultiExpression strides_b;
    if (j.contains("strides_b")) {
        for (const auto& stride : j["strides_b"]) {
            strides_b.push_back(symbolic::parse(stride.get<std::string>()));
        }
    }

    symbolic::Expression offset_a = symbolic::integer(0);
    if (j.contains("offset_a")) {
        offset_a = symbolic::parse(j["offset_a"].get<std::string>());
    }

    symbolic::Expression offset_b = symbolic::integer(0);
    if (j.contains("offset_b")) {
        offset_b = symbolic::parse(j["offset_b"].get<std::string>());
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder
        .add_library_node<MatMulNode>(parent, debug_info, shape_a, shape_b, strides_a, strides_b, offset_a, offset_b);
}

} // namespace tensor
} // namespace math
} // namespace sdfg