• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 26521975951

27 May 2026 03:45PM UTC coverage: 60.869% (-0.02%) from 60.886%
26521975951

push

github

web-flow
Libnode ptr edges (#719)

Migrating SDFGs to treat pointers as inputs to libNodes / Calls as scalars.
A pointer will only appear in an output edge if it is actually returned from the function (like malloc).

* Stdlib, Blas and Tensor Matmul nodes were migrated to this new format. Other, currently transitory Tensor Nodes are not yet migrated.
* DOCC version was bumped to incorporate previous docc-llvm versions (up to 0.4.0) that had been counted separately.
! Until all passes consider the use / leak of pointers as uncertainty / hiding potential writes, TensorNodes are declared as general side-effect.
* Lots of utility functions to centralize the creation (and edges) of various libNodes that needed to be changed.
* Fixed & unified docc paths across python and llvm front-ends.
* Skip BlockFusion test that fails due to its libNodes currently having side effects
~ Prevent a crash in DotViz when using symbolic offsets into structs
* Removing old ConstProp pass, it is not safe for the new pointer representation and should not be all too critical

961 of 1749 new or added lines in 52 files covered. (54.95%)

87 existing lines in 28 files now uncovered.

35225 of 57870 relevant lines covered (60.87%)

11046.32 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

48.05
/sdfg/src/data_flow/library_nodes/math/tensor/matmul_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/matmul_node.h"
2
#include <cstddef>
3
#include <string>
4

5
#include "sdfg/analysis/scope_analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
8
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
9
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
10
#include "sdfg/data_flow/tasklet.h"
11
#include "sdfg/element.h"
12
#include "sdfg/exceptions.h"
13
#include "sdfg/structured_control_flow/control_flow_node.h"
14
#include "sdfg/structured_control_flow/map.h"
15
#include "sdfg/structured_control_flow/sequence.h"
16
#include "sdfg/symbolic/symbolic.h"
17
#include "sdfg/types/pointer.h"
18
#include "sdfg/types/scalar.h"
19
#include "sdfg/types/tensor.h"
20
#include "sdfg/types/type.h"
21
#include "sdfg/types/utils.h"
22

23
namespace sdfg {
24
namespace math {
25
namespace tensor {
26

UNCOV
27
// Returns true when `strides` is exactly the dense row-major (C-order) stride vector that
// types::Tensor::strides_from_shape derives for `shape`: same length and every entry
// symbolically equal.
bool MatMulNode::has_basic_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides) {
    const auto expected = types::Tensor::strides_from_shape(shape);
    if (expected.size() != strides.size()) {
        return false;
    }

    size_t dim = 0;
    while (dim < strides.size()) {
        if (!symbolic::eq(expected[dim], strides[dim])) {
            return false;
        }
        ++dim;
    }
    return true;
}
×
39

40
// Returns true when `strides` describe a dense tensor of `shape` whose last two axes are
// stored swapped (the trailing matrix is laid out column-major). After swapping the last
// two extents and the last two strides back, the result must be plain row-major strides
// (see has_basic_strides).
//
// Requires at least two dimensions and exactly one stride per dimension; otherwise returns
// false. The stride-count check also guards the `size() - 2` indexing below, which
// previously read out of bounds when fewer than two strides were supplied.
bool MatMulNode::has_transposed_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides) {
    if (shape.size() < 2 || strides.size() != shape.size()) {
        return false;
    }

    const size_t row_axis = shape.size() - 2;
    const size_t col_axis = shape.size() - 1;

    // Undo the transposition of the trailing matrix in both the extents and the strides.
    symbolic::MultiExpression swapped_shape(shape);
    swapped_shape[row_axis] = shape[col_axis];
    swapped_shape[col_axis] = shape[row_axis];

    symbolic::MultiExpression swapped_strides(strides);
    swapped_strides[row_axis] = strides[col_axis];
    swapped_strides[col_axis] = strides[row_axis];

    return MatMulNode::has_basic_strides(swapped_shape, swapped_strides);
}
×
56

57
MatMulNode::MatMulNode(
58
    size_t element_id,
59
    const DebugInfo& debug_info,
60
    const graph::Vertex vertex,
61
    data_flow::DataFlowGraph& parent,
62
    const TensorLayout& layout_a,
63
    const TensorLayout& layout_b,
64
    types::PrimitiveType quantization
65
)
66
    : TensorNode(
5✔
67
          element_id,
5✔
68
          debug_info,
5✔
69
          vertex,
5✔
70
          parent,
5✔
71
          LibraryNodeType_MatMul,
5✔
72
          {},
5✔
73
          {"Y", "A", "B"},
5✔
74
          data_flow::ImplementationType_NONE
5✔
75
      ),
5✔
76
      fixed_quantization_(quantization), layout_a_(layout_a), layout_b_(layout_b) {
5✔
77
    if (layout_a.dims() < 2) {
5✔
78
        throw std::invalid_argument("MatMulNode: Input A must have at least 2 dimensions");
×
79
    }
×
80
    if (layout_b.dims() < 2) {
5✔
81
        throw std::invalid_argument("MatMulNode: Input B must have at least 2 dimensions");
×
82
    }
×
83
}
5✔
84

85
// Number of result rows: the second-to-last (innermost index 1) dimension of A.
symbolic::Expression MatMulNode::m() const {
    const auto& a = layout_a_;
    return a.get_dim_innermost(1);
}

// Number of result columns: the last (innermost index 0) dimension of B.
symbolic::Expression MatMulNode::n() const {
    const auto& b = layout_b_;
    return b.get_dim_innermost(0);
}

// Contraction length: the last dimension of A (validate() checks that it equals the
// second-to-last dimension of B).
symbolic::Expression MatMulNode::k() const {
    const auto& a = layout_a_;
    return a.get_dim_innermost(0);
}
7✔
99

NEW
100
// Layout (shape, strides, offset) of operand A as given at construction.
const TensorLayout& MatMulNode::layout_a() const {
    return this->layout_a_;
}

// Layout (shape, strides, offset) of operand B as given at construction.
const TensorLayout& MatMulNode::layout_b() const {
    return this->layout_b_;
}
×
103

104
void MatMulNode::validate(const Function& function) const {
5✔
105
    TensorNode::validate(function);
5✔
106

107
    auto& graph = this->get_parent();
5✔
108

109
    // Check that we have exactly 2 inputs and 1 output
110
    if (graph.in_degree(*this) != 3) {
5✔
NEW
111
        throw InvalidSDFGException("MatMulNode: Expected exactly 3 inputs (Y, A, B)");
×
UNCOV
112
    }
×
113
    if (graph.out_degree(*this) != 0) {
5✔
NEW
114
        throw InvalidSDFGException("MatMulNode: Expected no outputs");
×
UNCOV
115
    }
×
116

117
    // Validate K dimension matches between A and B
118
    auto k_a = layout_a_.get_dim_innermost(0);
5✔
119
    auto k_b = layout_b_.get_dim_innermost(1);
5✔
120
    if (!symbolic::eq(k_a, k_b)) {
5✔
121
        throw InvalidSDFGException(
×
122
            "MatMulNode: K dimension mismatch. A has K=" + k_a->__str__() + ", B has K=" + k_b->__str__()
×
123
        );
×
124
    }
×
125
}
5✔
126

127
// Collects every symbol referenced by the layouts of A and B (shape, strides, offsets).
symbolic::SymbolSet MatMulNode::symbols() const {
    symbolic::SymbolSet collected;
    for (const TensorLayout* layout : {&layout_a_, &layout_b_}) {
        layout->collect_symbols(collected);
    }
    return collected;
}
1✔
133

134
// Substitutes `old_expression` with `new_expression` in the layouts of A and B (A first).
void MatMulNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    for (TensorLayout* layout : {&layout_a_, &layout_b_}) {
        layout->replace_symbols(old_expression, new_expression);
    }
}
×
138

139
std::unique_ptr<data_flow::DataFlowNode> MatMulNode::
140
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
141
    return std::unique_ptr<data_flow::DataFlowNode>(
×
NEW
142
        new MatMulNode(element_id, debug_info(), vertex, parent, layout_a_, layout_b_, fixed_quantization_)
×
NEW
143
    );
×
NEW
144
}
×
145

NEW
146
// Quantization requested at construction; may be QUANTIZATION_MATCH_INPUTS.
types::PrimitiveType MatMulNode::fixed_quantization() const {
    return this->fixed_quantization_;
}

// Effective computation type: the fixed quantization if one was requested, otherwise
// whatever this->primitive_type(data_flow_graph) reports.
types::PrimitiveType MatMulNode::quantization(const data_flow::DataFlowGraph& data_flow_graph) const {
    if (fixed_quantization_ == QUANTIZATION_MATCH_INPUTS) {
        return this->primitive_type(data_flow_graph);
    }
    return fixed_quantization_;
}
×
155

156
// Single type used for both inputs and computation, or std::nullopt if none exists.
//
// With QUANTIZATION_MATCH_INPUTS this is simply primitive_type(data_flow_graph). With a
// fixed quantization, it is that type only when primitive_type(data_flow_graph) agrees with
// it; otherwise there is no uniform type and std::nullopt is returned.
std::optional<types::PrimitiveType> MatMulNode::uniform_quantization(const data_flow::DataFlowGraph& data_flow_graph
) const {
    const auto inferred = this->primitive_type(data_flow_graph);
    if (fixed_quantization_ == QUANTIZATION_MATCH_INPUTS) {
        return inferred;
    }
    if (inferred != fixed_quantization_) {
        return std::nullopt;
    }
    return fixed_quantization_;
}
5✔
169

170
std::string MatMulNode::toStr() const {
×
171
    std::stringstream ss;
×
172
    ss << "MatMul(";
×
NEW
173
    ss << types::primitive_type_to_string(fixed_quantization_) << ", ";
×
NEW
174
    ss << "A: " << layout_a_;
×
NEW
175
    ss << ", B: " << layout_b_;
×
176
    ss << ")";
×
177
    return ss.str();
×
178
}
×
179

NEW
180
// Estimated floating-point operation count of the (batched) matrix product.
//
// Each of the M * N output elements needs K multiplications and K - 1 additions. Batch
// dimensions (all but the last two) are aligned from the innermost side. A dimension that
// only one operand has, or that has extent 1 in one operand, is broadcast. The per-matrix
// count is multiplied by the resulting extent of every batch dimension.
//
// Throws InvalidSDFGException if a batch dimension present in both operands has
// different extents and neither of them is 1.
symbolic::Expression MatMulNode::flop() const {
    auto res_elems = symbolic::mul(this->m(), this->n());
    auto k = this->k();

    auto mul_ops = symbolic::mul(res_elems, k);
    auto add_ops = symbolic::mul(res_elems, symbolic::sub(k, symbolic::one()));
    auto per_mat = symbolic::add(mul_ops, add_ops);

    const size_t a_dims = layout_a_.dims();
    const size_t b_dims = layout_b_.dims();
    const size_t max_dims = std::max(a_dims, b_dims);
    if (max_dims <= 2) {
        return per_mat;
    }

    std::vector<symbolic::Expression> factors{per_mat};
    for (size_t i = 2; i < max_dims; ++i) {
        const bool in_a = i < a_dims;
        const bool in_b = i < b_dims;
        if (in_a && in_b) {
            auto dim_a = layout_a_.get_dim_innermost(i);
            auto dim_b = layout_b_.get_dim_innermost(i);
            if (symbolic::eq(dim_a, dim_b) || symbolic::eq(dim_b, symbolic::one())) {
                factors.push_back(dim_a);
            } else if (symbolic::eq(dim_a, symbolic::one())) {
                // A is broadcast along this dimension.
                factors.push_back(dim_b);
            } else {
                throw InvalidSDFGException(
                    "Batch dimension " + std::to_string(i) + " mismatch between A and B. A has " +
                    dim_a->__str__() + ", B has " + dim_b->__str__()
                );
            }
        } else if (in_a) {
            factors.push_back(layout_a_.get_dim_innermost(i));
        } else {
            // i < max_dims guarantees B has this dimension when A does not.
            factors.push_back(layout_b_.get_dim_innermost(i));
        }
    }
    return SymEngine::mul(factors);
}
×
225

226
void free_after_copy(
227
    const std::string& copy_name, builder::StructuredSDFGBuilder& builder, structured_control_flow::Sequence& parent
228
) {
×
229
    auto& block = builder.add_block(parent, {}, DebugInfo());
×
230
    auto& access_in = builder.add_access(block, copy_name);
×
231
    auto& free_node = builder.add_library_node<stdlib::FreeNode>(block, DebugInfo());
×
232
    builder.add_computational_memlet(
×
233
        block, access_in, free_node, "_ptr", {}, types::Pointer(types::Scalar(types::PrimitiveType::Void))
×
234
    );
×
235
}
×
236

237
bool MatMulNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
238
    auto& dataflow = this->get_parent();
5✔
239
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
5✔
240

241
    if (dataflow.in_degree(*this) != 3 || dataflow.out_degree(*this) != 0) {
5✔
242
        return false;
×
243
    }
×
244

245
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
246
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
5✔
247
    int index = parent.index(block);
5✔
248
    auto& transition = parent.at(index).second;
5✔
249

250
    // Get input and output edges
251
    auto iedges = dataflow.in_edges_by_connector(*this);
5✔
252
    if (iedges.size() != 3) {
5✔
253
        return false;
×
254
    }
×
255
    auto* iedge_y = iedges.at(Y_INPUT_IDX);
5✔
256
    auto* iedge_a = iedges.at(A_INPUT_IDX);
5✔
257
    auto* iedge_b = iedges.at(B_INPUT_IDX);
5✔
258

259
    // Check if legal - access nodes must not have other connections
260
    auto& input_node_a = static_cast<data_flow::AccessNode&>(iedge_a->src());
5✔
261
    auto& input_node_b = static_cast<data_flow::AccessNode&>(iedge_b->src());
5✔
262
    auto& output_ptr = static_cast<data_flow::AccessNode&>(iedge_y->src());
5✔
263

264
    if (dataflow.in_degree(input_node_a) != 0 || dataflow.in_degree(input_node_b) != 0 ||
5✔
265
        dataflow.in_degree(output_ptr) != 0) {
5✔
266
        return false;
×
267
    }
×
268

269
    // Determine BLAS precision from primitive type
270
    auto prim_type = this->uniform_quantization(dataflow);
5✔
271
    if (!prim_type) {
5✔
NEW
272
        return false;
×
NEW
273
    }
×
274
    blas::BLAS_Precision precision;
5✔
275
    switch (prim_type.value()) {
5✔
276
        case types::PrimitiveType::Half:
×
277
            precision = blas::BLAS_Precision::h;
×
278
            break;
×
279
        case types::PrimitiveType::Float:
3✔
280
            precision = blas::BLAS_Precision::s;
3✔
281
            break;
3✔
282
        case types::PrimitiveType::Double:
1✔
283
            precision = blas::BLAS_Precision::d;
1✔
284
            break;
1✔
285
        default:
1✔
286
            // GEMM only supports floating point types, fall back to naive expansion
287
            return false;
1✔
288
    };
5✔
289

290
    // Add new graph after the current block
291
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
4✔
292

293
    auto copy_name_a = input_node_a.data();
4✔
294
    auto copy_name_b = input_node_b.data();
4✔
295

296
    // Check if A and B have basic strides and whether they are transposed in the last dimension
297
    blas::BLAS_Transpose trans_a, trans_b;
4✔
298
    if (layout_a_.has_linear_accesses_no_padding()) {
4✔
299
        trans_a = blas::BLAS_Transpose::No;
4✔
300
    } else if (layout_a_.has_transposed_strides_no_padding()) {
4✔
301
        trans_a = blas::BLAS_Transpose::Trans;
×
302
    } else {
×
303
        trans_a = blas::BLAS_Transpose::No;
×
304
        throw InvalidSDFGException("A must be in c-order");
×
305
    }
×
306
    if (layout_b_.has_linear_accesses_no_padding()) {
4✔
307
        trans_b = blas::BLAS_Transpose::No;
4✔
308
    } else if (layout_b_.has_transposed_strides_no_padding()) {
4✔
309
        trans_b = blas::BLAS_Transpose::Trans;
×
310
    } else {
×
311
        trans_b = blas::BLAS_Transpose::No;
×
312
        throw InvalidSDFGException("B must be in c-order");
×
313
    }
×
314

315
    // Create maps for batch dimensions and M, N dimensions
316
    structured_control_flow::Sequence* last_scope = &new_sequence;
4✔
317
    structured_control_flow::Map* last_map = nullptr;
4✔
318
    symbolic::MultiExpression batch_vars;
4✔
319

320
    // Compute batch dimensions (all except last 2)
321
    size_t batch_dims_a = layout_a_.dims() - 2;
4✔
322
    size_t batch_dims_b = layout_b_.dims() - 2;
4✔
323
    size_t max_batch_dims = std::max(batch_dims_a, batch_dims_b);
4✔
324

325
    // Create maps for batch dimensions (using broadcasting)
326
    for (size_t i = 0; i < max_batch_dims; ++i) {
5✔
327
        std::string indvar_str = builder.find_new_name("_b");
1✔
328
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
1✔
329

330
        auto indvar = symbolic::symbol(indvar_str);
1✔
331
        auto init = symbolic::zero();
1✔
332
        auto update = symbolic::add(indvar, symbolic::one());
1✔
333

334
        // Determine the bound for this batch dimension (max of A and B for broadcasting)
335
        symbolic::Expression bound;
1✔
336
        size_t a_idx = batch_dims_a >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_a) : SIZE_MAX;
1✔
337
        size_t b_idx = batch_dims_b >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_b) : SIZE_MAX;
1✔
338

339
        if (a_idx != SIZE_MAX && b_idx != SIZE_MAX) {
1✔
340
            // Both have this dimension - they should be equal or one should be 1 (broadcasting)
341
            bound = layout_a_.get_dim(a_idx); // Assume they match or broadcasting is handled
1✔
342
        } else if (a_idx != SIZE_MAX) {
1✔
NEW
343
            bound = layout_a_.get_dim(a_idx);
×
344
        } else {
×
NEW
345
            bound = layout_b_.get_dim(b_idx);
×
346
        }
×
347

348
        auto condition = symbolic::Lt(indvar, bound);
1✔
349
        last_map = &builder.add_map(
1✔
350
            *last_scope,
1✔
351
            indvar,
1✔
352
            condition,
1✔
353
            init,
1✔
354
            update,
1✔
355
            structured_control_flow::ScheduleType_Sequential::create(),
1✔
356
            {},
1✔
357
            block.debug_info()
1✔
358
        );
1✔
359
        last_scope = &last_map->root();
1✔
360
        batch_vars.push_back(indvar);
1✔
361
    }
1✔
362

363
    auto& ref_block = builder.add_block(*last_scope, {}, block.debug_info());
4✔
364

365
    auto scalar_type = types::Scalar(prim_type.value());
4✔
366

367
    // Compute offsets for this batch iteration
368
    // For A: base_offset_a = offset_a + sum_i(batch_idx_i * batch_stride_a_i)
369
    symbolic::Expression a_batch_offset = layout_a_.offset();
4✔
370
    for (size_t i = 0; i < batch_dims_a; ++i) {
5✔
371
        size_t batch_idx = max_batch_dims - batch_dims_a + i;
1✔
372
        a_batch_offset = symbolic::add(a_batch_offset, symbolic::mul(batch_vars[batch_idx], layout_a_.get_stride(i)));
1✔
373
    }
1✔
374

375
    // For B: base_offset_b = offset_b + sum_i(batch_idx_i * batch_stride_b_i)
376
    symbolic::Expression b_batch_offset = layout_b_.offset();
4✔
377
    for (size_t i = 0; i < batch_dims_b; ++i) {
5✔
378
        size_t batch_idx = max_batch_dims - batch_dims_b + i;
1✔
379
        b_batch_offset = symbolic::add(b_batch_offset, symbolic::mul(batch_vars[batch_idx], layout_b_.get_stride(i)));
1✔
380
    }
1✔
381

382
    // Compute output batch offset (same as batch_vars pattern for Y)
383
    symbolic::Expression c_batch_offset = symbolic::integer(0);
4✔
384
    for (size_t i = 0; i < batch_vars.size(); ++i) {
5✔
385
        // Output has shape [batch..., M, N] with row-major strides
386
        // Stride for batch dim i is: M * N * product of remaining batch dims
387
        symbolic::Expression c_stride = symbolic::mul(this->m(), this->n());
1✔
388
        for (size_t j = i + 1; j < batch_vars.size(); ++j) {
1✔
389
            // Multiply by subsequent batch dimensions
390
            if (j < batch_dims_a) {
×
NEW
391
                c_stride = symbolic::mul(c_stride, layout_a_.get_dim(j));
×
392
            } else if (j - batch_dims_a < batch_dims_b) {
×
NEW
393
                c_stride = symbolic::mul(c_stride, layout_b_.get_dim(j - batch_dims_a));
×
394
            }
×
395
        }
×
396
        c_batch_offset = symbolic::add(c_batch_offset, symbolic::mul(batch_vars[i], c_stride));
1✔
397
    }
1✔
398

399
    // Create access nodes
400
    auto& a_access = builder.add_access(ref_block, copy_name_a, debug_info());
4✔
401
    auto& b_access = builder.add_access(ref_block, copy_name_b, debug_info());
4✔
402
    auto& c_access_in = builder.add_access(ref_block, output_ptr.data(), debug_info());
4✔
403

404
    std::string ref_name_a = builder.find_new_name(copy_name_a + "_ref");
4✔
405
    builder.add_container(ref_name_a, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
406
    auto& a_access_ref = builder.add_access(ref_block, ref_name_a, debug_info());
4✔
407
    std::string ref_name_b = builder.find_new_name(copy_name_b + "_ref");
4✔
408
    builder.add_container(ref_name_b, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
409
    auto& b_access_ref = builder.add_access(ref_block, ref_name_b, debug_info());
4✔
410
    std::string ref_name_c = builder.find_new_name(output_ptr.data() + "_ref");
4✔
411
    builder.add_container(ref_name_c, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
412
    auto& c_access_ref_in = builder.add_access(ref_block, ref_name_c, debug_info());
4✔
413

414
    builder.add_reference_memlet(
4✔
415
        ref_block, a_access, a_access_ref, {a_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
416
    );
4✔
417
    builder.add_reference_memlet(
4✔
418
        ref_block, b_access, b_access_ref, {b_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
419
    );
4✔
420
    builder.add_reference_memlet(
4✔
421
        ref_block, c_access_in, c_access_ref_in, {c_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
422
    );
4✔
423

424
    // Create block with GEMM library node
425
    auto& gemm_block = builder.add_block(*last_scope, {}, block.debug_info());
4✔
426

427
    // Leading dimensions: stride of the row dimension (second-to-last dim)
428
    symbolic::Expression lda, ldb;
4✔
429
    if (trans_a == blas::BLAS_Transpose::No) {
4✔
430
        // For row-major A [m * k] -> lda = k
431
        lda = layout_a_.get_stride_innermost(1);
4✔
432
    } else {
4✔
433
        // For row-major A [m * k] -> lda = m
NEW
434
        lda = layout_a_.get_stride_innermost(0);
×
435
    }
×
436
    if (trans_b == blas::BLAS_Transpose::No) {
4✔
437
        // For row-major B [k * n] -> ldb = n
438
        ldb = layout_b_.get_stride_innermost(1);
4✔
439
    } else {
4✔
440
        // For row-major B [k * n] -> ldb = k
NEW
441
        ldb = layout_b_.get_stride_innermost(0);
×
442
    }
×
443
    // For row-major C [m * n] -> ldc = n
444
    auto ldc = this->n();
4✔
445

446
    // Add GEMM node: C = alpha * A * B + beta * C
447
    // With alpha = 1.0, beta = 0.0: C = A * B
448
    auto& gemm_node = builder.add_library_node<blas::GEMMNode>(
4✔
449
        gemm_block,
4✔
450
        debug_info(),
4✔
451
        blas::ImplementationType_BLAS,
4✔
452
        precision,
4✔
453
        blas::BLAS_Layout::RowMajor,
4✔
454
        trans_a,
4✔
455
        trans_b,
4✔
456
        this->m(),
4✔
457
        this->n(),
4✔
458
        this->k(),
4✔
459
        lda,
4✔
460
        ldb,
4✔
461
        ldc
4✔
462
    );
4✔
463

464
    auto& a_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_a, debug_info());
4✔
465
    auto& b_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_b, debug_info());
4✔
466
    auto& c_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_c, debug_info());
4✔
467

468
    // Create alpha and beta constants
469
    auto& alpha_const = builder.add_constant(gemm_block, "1.0", scalar_type, debug_info());
4✔
470
    auto& beta_const = builder.add_constant(gemm_block, "0.0", scalar_type, debug_info());
4✔
471

472
    // Connect memlets with batch offsets
473
    // Input A with offset
474
    builder.add_computational_memlet(
4✔
475
        gemm_block, a_access_ref_in_gemm, gemm_node, "__A", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
476
    );
4✔
477
    // Input B with offset
478
    builder.add_computational_memlet(
4✔
479
        gemm_block, b_access_ref_in_gemm, gemm_node, "__B", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
480
    );
4✔
481
    // Input C (for beta * C, but beta=0 so just needs to be connected)
482
    builder.add_computational_memlet(
4✔
483
        gemm_block, c_access_ref_in_gemm, gemm_node, "__C", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
484
    );
4✔
485
    // Alpha constant
486
    builder.add_computational_memlet(gemm_block, alpha_const, gemm_node, "__alpha", {}, scalar_type, debug_info());
4✔
487
    // Beta constant
488
    builder.add_computational_memlet(gemm_block, beta_const, gemm_node, "__beta", {}, scalar_type, debug_info());
4✔
489

490
    // Free copies if we made them
491
    if (copy_name_a != input_node_a.data()) {
4✔
492
        free_after_copy(copy_name_a, builder, new_sequence);
×
493
    }
×
494
    if (copy_name_b != input_node_b.data()) {
4✔
495
        free_after_copy(copy_name_b, builder, new_sequence);
×
496
    }
×
497

498
    // Remove the original nodes
499
    builder.remove_memlet(block, *iedge_a);
4✔
500
    builder.remove_memlet(block, *iedge_b);
4✔
501
    builder.remove_memlet(block, *iedge_y);
4✔
502
    if (&input_node_a != &input_node_b) {
4✔
503
        builder.remove_node(block, input_node_a);
4✔
504
    }
4✔
505
    builder.remove_node(block, input_node_b);
4✔
506
    builder.remove_node(block, output_ptr);
4✔
507
    builder.remove_node(block, *this);
4✔
508
    builder.remove_child(parent, index + 1);
4✔
509

510
    return true;
4✔
511
}
4✔
512

513
// Serializes a MatMulNode to JSON.
//
// Keys written:
//  - "code": the library node code,
//  - "layout_a" / "layout_b": written via TensorLayout::serialize_to_json,
//  - "result_quant": the fixed quantization.
//
// library_node must be a MatMulNode; this is not checked (static_cast).
nlohmann::json MatMulNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const MatMulNode& matmul_node = static_cast<const MatMulNode&>(library_node);
    nlohmann::json j;

    j["code"] = matmul_node.code().value();

    matmul_node.layout_a().serialize_to_json(j["layout_a"]);
    matmul_node.layout_b().serialize_to_json(j["layout_b"]);

    j["result_quant"] = matmul_node.fixed_quantization();

    return j;
}
×
528

529
// Reconstructs a MatMulNode from JSON and adds it to `parent`.
//
// Two input formats are accepted:
//  - Current: "layout_a" and "layout_b" objects, read via TensorLayout::deserialize_from_json.
//  - Legacy: "shape_a"/"shape_b" arrays of expression strings, plus optional
//    "strides_a"/"strides_b" arrays (empty when absent) and optional
//    "offset_a"/"offset_b" strings (0 when absent).
// "result_quant" is optional and defaults to QUANTIZATION_MATCH_INPUTS.
//
// Required keys are read with json::at, so a missing key raises nlohmann::json::out_of_range
// even when asserts are compiled out. The previous code used const operator[], which is
// undefined behaviour for a missing key.
data_flow::LibraryNode& MatMulNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // Parses a JSON array of expression strings, e.g. ["N", "M * 4"].
    const auto parse_expression_list = [](const nlohmann::json& array) -> symbolic::MultiExpression {
        symbolic::MultiExpression exprs;
        for (const auto& entry : array) {
            exprs.push_back(symbolic::parse(entry.get<std::string>()));
        }
        return exprs;
    };

    std::optional<TensorLayout> layout_a;
    std::optional<TensorLayout> layout_b;

    auto layout_a_it = j.find("layout_a");
    if (layout_a_it != j.end()) {
        layout_a = TensorLayout::deserialize_from_json(*layout_a_it);
        layout_b = TensorLayout::deserialize_from_json(j.at("layout_b"));
    } else {
        // Legacy format with separate shape / strides / offset fields.
        assert(j.contains("shape_a"));
        assert(j.contains("shape_b"));

        symbolic::MultiExpression shape_a = parse_expression_list(j.at("shape_a"));
        symbolic::MultiExpression shape_b = parse_expression_list(j.at("shape_b"));

        symbolic::MultiExpression strides_a;
        if (j.contains("strides_a")) {
            strides_a = parse_expression_list(j.at("strides_a"));
        }
        symbolic::MultiExpression strides_b;
        if (j.contains("strides_b")) {
            strides_b = parse_expression_list(j.at("strides_b"));
        }

        symbolic::Expression offset_a = symbolic::integer(0);
        if (j.contains("offset_a")) {
            offset_a = symbolic::parse(j.at("offset_a").get<std::string>());
        }
        symbolic::Expression offset_b = symbolic::integer(0);
        if (j.contains("offset_b")) {
            offset_b = symbolic::parse(j.at("offset_b").get<std::string>());
        }

        layout_a = TensorLayout(shape_a, strides_a, offset_a);
        layout_b = TensorLayout(shape_b, strides_b, offset_b);
    }

    types::PrimitiveType quantization = QUANTIZATION_MATCH_INPUTS;
    auto result_quant = j.find("result_quant");
    if (result_quant != j.end()) {
        quantization = result_quant->get<types::PrimitiveType>();
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<MatMulNode>(parent, debug_info, layout_a.value(), layout_b.value(), quantization);
}
×
597

598
} // namespace tensor
599
} // namespace math
600
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc