• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27007027060

05 Jun 2026 09:28AM UTC coverage: 61.275% (-0.02%) from 61.292%
27007027060

push

github

web-flow
Improve Quantization support on TensorNodes (#736)

* Added DataFlowGraph.find_standalone_exit() following the pattern of find_standalone_entry() to abstract away edge types.
* LibNodeDispatcher does not allow missing inputs.
  To accommodate this, ConvNode is now explicitly configured with whether or not it has a bias.
* Fixed toStr() of the elementwise CMath node.

---------

Co-authored-by: Moritz Timmer <25349452+Moehre2@users.noreply.github.com>

10 of 43 new or added lines in 8 files covered. (23.26%)

1 existing line in 1 file now uncovered.

35592 of 58086 relevant lines covered (61.27%)

11015.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

46.28
/sdfg/src/data_flow/library_nodes/math/tensor/matmul_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/matmul_node.h"
2
#include <cstddef>
3
#include <string>
4

5
#include "sdfg/analysis/scope_analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
8
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
9
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
10
#include "sdfg/data_flow/tasklet.h"
11
#include "sdfg/element.h"
12
#include "sdfg/exceptions.h"
13
#include "sdfg/structured_control_flow/control_flow_node.h"
14
#include "sdfg/structured_control_flow/map.h"
15
#include "sdfg/structured_control_flow/sequence.h"
16
#include "sdfg/symbolic/symbolic.h"
17
#include "sdfg/types/pointer.h"
18
#include "sdfg/types/scalar.h"
19
#include "sdfg/types/tensor.h"
20
#include "sdfg/types/type.h"
21
#include "sdfg/types/utils.h"
22

23
namespace sdfg {
24
namespace math {
25
namespace tensor {
26

27
/**
 * Returns true iff `strides` are exactly the strides that types::Tensor::strides_from_shape()
 * derives from `shape` (same rank, and every entry symbolically equal).
 */
bool MatMulNode::has_basic_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides) {
    const auto expected = types::Tensor::strides_from_shape(shape);
    if (strides.size() != expected.size()) {
        return false;
    }

    size_t dim = 0;
    while (dim < strides.size()) {
        if (!symbolic::eq(expected[dim], strides[dim])) {
            // First differing stride decides the answer
            return false;
        }
        ++dim;
    }
    return true;
}
×
39

40
/**
 * Returns true iff swapping the last two entries of both `shape` and `strides` yields a layout
 * that satisfies has_basic_strides(), i.e. the two innermost dimensions are stored transposed.
 *
 * Returns false when the rank is below 2 or when `shape` and `strides` have different ranks.
 */
bool MatMulNode::has_transposed_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides) {
    // The rank check also guards the strides[rank - 2] accesses below: with fewer than two
    // strides the unsigned index would underflow. Mismatched ranks could never satisfy
    // has_basic_strides() anyway, so returning false early does not change any defined result.
    if (shape.size() < 2 || strides.size() != shape.size()) {
        return false;
    }
    const size_t rank = shape.size();

    // Shape with the two innermost dimensions swapped
    symbolic::MultiExpression new_shape;
    new_shape.reserve(rank);
    for (size_t i = 0; i < rank - 2; i++) {
        new_shape.push_back(shape[i]);
    }
    new_shape.push_back(shape[rank - 1]);
    new_shape.push_back(shape[rank - 2]);

    // Strides with the two innermost entries swapped
    symbolic::MultiExpression transposed_strides(strides);
    transposed_strides[rank - 2] = strides[rank - 1];
    transposed_strides[rank - 1] = strides[rank - 2];

    return MatMulNode::has_basic_strides(new_shape, transposed_strides);
}
×
56

57
/**
 * Creates a MatMul node with connectors Y (result pointer), A and B, all as inputs.
 *
 * @throws std::invalid_argument if either layout has fewer than 2 dimensions.
 */
MatMulNode::MatMulNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const TensorLayout& layout_a,
    const TensorLayout& layout_b,
    QuantizationType quantization,
    const data_flow::ImplementationType& impl_type
)
    : TensorNode(element_id, debug_info, vertex, parent, LibraryNodeType_MatMul, {}, {"Y", "A", "B"}, impl_type),
      fixed_quantization_(quantization), layout_a_(layout_a), layout_b_(layout_b) {
    // Both operands must be (possibly batched) matrices; A is checked before B.
    const bool a_is_matrix = layout_a.dims() >= 2;
    if (!a_is_matrix) {
        throw std::invalid_argument("MatMulNode: Input A must have at least 2 dimensions");
    }
    const bool b_is_matrix = layout_b.dims() >= 2;
    if (!b_is_matrix) {
        throw std::invalid_argument("MatMulNode: Input B must have at least 2 dimensions");
    }
}
5✔
76

77
symbolic::Expression MatMulNode::m() const {
    // Rows of the product: the second-innermost dimension of A ([..., M, K]).
    const TensorLayout& a = this->layout_a_;
    return a.get_dim_innermost(1);
}
8✔
81

82
symbolic::Expression MatMulNode::n() const {
    // Columns of the product: the innermost dimension of B ([..., K, N]).
    const TensorLayout& b = this->layout_b_;
    return b.get_dim_innermost(0);
}
12✔
86

87
symbolic::Expression MatMulNode::k() const {
    // Contraction length: the innermost dimension of A ([..., M, K]).
    // validate() checks that it equals B's second-innermost dimension.
    const TensorLayout& a = this->layout_a_;
    return a.get_dim_innermost(0);
}
7✔
91

92
const TensorLayout& MatMulNode::layout_a() const {
    // Layout of the left operand A.
    return this->layout_a_;
}
×
93

94
const TensorLayout& MatMulNode::layout_b() const {
    // Layout of the right operand B.
    return this->layout_b_;
}
×
95

96
void MatMulNode::validate(const Function& function) const {
5✔
97
    TensorNode::validate(function);
5✔
98

99
    auto& graph = this->get_parent();
5✔
100

101
    // Check that we have exactly 2 inputs and 1 output
102
    if (graph.in_degree(*this) != 3) {
5✔
103
        throw InvalidSDFGException("MatMulNode: Expected exactly 3 inputs (Y, A, B)");
×
104
    }
×
105
    if (graph.out_degree(*this) != 0) {
5✔
106
        throw InvalidSDFGException("MatMulNode: Expected no outputs");
×
107
    }
×
108

109
    // Validate K dimension matches between A and B
110
    auto k_a = layout_a_.get_dim_innermost(0);
5✔
111
    auto k_b = layout_b_.get_dim_innermost(1);
5✔
112
    if (!symbolic::eq(k_a, k_b)) {
5✔
113
        throw InvalidSDFGException(
×
114
            "MatMulNode: K dimension mismatch. A has K=" + k_a->__str__() + ", B has K=" + k_b->__str__()
×
115
        );
×
116
    }
×
117
}
5✔
118

119
symbolic::SymbolSet MatMulNode::symbols() const {
    // Union of the symbols referenced by both operand layouts (A collected first, then B).
    symbolic::SymbolSet result;
    for (const TensorLayout* layout : {&layout_a_, &layout_b_}) {
        layout->collect_symbols(result);
    }
    return result;
}
1✔
125

126
void MatMulNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Substitute the expression in both operand layouts, A first, then B.
    for (TensorLayout* layout : {&layout_a_, &layout_b_}) {
        layout->replace_symbols(old_expression, new_expression);
    }
}
×
130

131
std::unique_ptr<data_flow::DataFlowNode> MatMulNode::
132
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
133
    return std::unique_ptr<data_flow::DataFlowNode>(new MatMulNode(
×
134
        element_id, debug_info(), vertex, parent, layout_a_, layout_b_, fixed_quantization_, implementation_type_
×
135
    ));
×
136
}
×
137

138
types::PrimitiveType MatMulNode::fixed_quantization() const {
    // Raw configured value; may be QUANTIZATION_MATCH_INPUTS (see quantization()).
    return this->fixed_quantization_;
}
×
139

NEW
140
void MatMulNode::set_fixed_quantization(const QuantizationType quant) { fixed_quantization_ = quant; }
×
141

142
types::PrimitiveType MatMulNode::quantization(const data_flow::DataFlowGraph& data_flow_graph) const {
    // Without a fixed quantization, use primitive_type() on the given graph.
    if (fixed_quantization_ == QUANTIZATION_MATCH_INPUTS) {
        return this->primitive_type(data_flow_graph);
    }
    // A fixed quantization always wins.
    return fixed_quantization_;
}
×
149

150
/**
 * The single primitive type used for both inputs and result, if there is one.
 *
 * With QUANTIZATION_MATCH_INPUTS this is primitive_type() on the graph. With a fixed
 * quantization it is that quantization, but only if primitive_type() agrees; otherwise nullopt.
 */
std::optional<types::PrimitiveType> MatMulNode::uniform_quantization(const data_flow::DataFlowGraph& data_flow_graph
) const {
    const auto inferred = this->primitive_type(data_flow_graph);
    if (fixed_quantization_ == QUANTIZATION_MATCH_INPUTS) {
        return inferred;
    }
    if (inferred == fixed_quantization_) {
        return fixed_quantization_;
    }
    // Fixed quantization differs from what the graph yields: not uniform.
    return std::nullopt;
}
5✔
163

164
std::string MatMulNode::toStr() const {
×
165
    std::stringstream ss;
×
166
    ss << "MatMul(";
×
167
    ss << types::primitive_type_to_string(fixed_quantization_) << ", ";
×
168
    ss << "A: " << layout_a_;
×
169
    ss << ", B: " << layout_b_;
×
170
    ss << ")";
×
171
    return ss.str();
×
172
}
×
173

174
/**
 * Symbolic number of floating-point operations performed by this node.
 *
 * Each of the M*N result elements costs K multiplications and K-1 additions. If either operand
 * has more than two dimensions, that per-matrix count is multiplied by every batch dimension.
 * Batch dimensions are aligned from the innermost side (get_dim_innermost(i) for i >= 2).
 *
 * @return the operation count, or SymEngine::null if a batch dimension is reported by
 *         neither layout.
 * @throws InvalidSDFGException if a batch dimension present in both A and B differs.
 */
symbolic::Expression MatMulNode::flop() const {
    auto res_elems = symbolic::mul(this->m(), this->n());
    auto k = this->k();

    auto mul_ops = symbolic::mul(res_elems, k);
    auto add_ops = symbolic::mul(res_elems, symbolic::sub(k, symbolic::one()));
    auto per_mat = symbolic::add(mul_ops, add_ops);

    int a_dims = layout_a_.dims();
    int b_dims = layout_b_.dims();
    if (a_dims <= 2 && b_dims <= 2) {
        // Plain (unbatched) matrix product
        return per_mat;
    }

    std::vector<symbolic::Expression> factors{per_mat};
    auto max_dims = std::max(a_dims, b_dims);
    for (int i = 2; i < max_dims; ++i) {
        symbolic::Expression dim_a;
        symbolic::Expression dim_b;
        if (i < a_dims) {
            dim_a = layout_a_.get_dim_innermost(i);
        }
        if (i < b_dims) {
            dim_b = layout_b_.get_dim_innermost(i);
        }

        // Logical (not bitwise) combination of the presence flags
        const bool has_a = !dim_a.is_null();
        const bool has_b = !dim_b.is_null();
        if (has_a && has_b) {
            if (!symbolic::eq(dim_a, dim_b)) {
                throw InvalidSDFGException(
                    "Batch dimension " + std::to_string(i) + " mismatch between A and B. A has " +
                    dim_a->__str__() + ", B has " + dim_b->__str__()
                );
            }
            factors.push_back(dim_a);
        } else if (has_a) {
            factors.push_back(dim_a);
        } else if (has_b) {
            factors.push_back(dim_b);
        } else {
            // Neither layout reports a size for this dimension: the count is unknown.
            return SymEngine::null;
        }
    }
    return SymEngine::mul(factors);
}
×
219

220
void free_after_copy(
221
    const std::string& copy_name, builder::StructuredSDFGBuilder& builder, structured_control_flow::Sequence& parent
222
) {
×
223
    auto& block = builder.add_block(parent, {}, DebugInfo());
×
224
    auto& access_in = builder.add_access(block, copy_name);
×
225
    auto& free_node = builder.add_library_node<stdlib::FreeNode>(block, DebugInfo());
×
226
    builder.add_computational_memlet(
×
227
        block, access_in, free_node, "_ptr", {}, types::Pointer(types::Scalar(types::PrimitiveType::Void))
×
228
    );
×
229
}
×
230

231
bool MatMulNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
232
    auto& dataflow = this->get_parent();
5✔
233
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
5✔
234

235
    if (dataflow.in_degree(*this) != 3 || dataflow.out_degree(*this) != 0) {
5✔
236
        return false;
×
237
    }
×
238

239
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
240
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
5✔
241
    int index = parent.index(block);
5✔
242
    auto& transition = parent.at(index).second;
5✔
243

244
    // Get input and output edges
245
    auto iedges = dataflow.in_edges_by_connector(*this);
5✔
246
    if (iedges.size() != 3) {
5✔
247
        return false;
×
248
    }
×
249
    auto* iedge_y = iedges.at(Y_INPUT_IDX);
5✔
250
    auto* iedge_a = iedges.at(A_INPUT_IDX);
5✔
251
    auto* iedge_b = iedges.at(B_INPUT_IDX);
5✔
252

253
    // Check if legal - access nodes must not have other connections
254
    auto& input_node_a = static_cast<data_flow::AccessNode&>(iedge_a->src());
5✔
255
    auto& input_node_b = static_cast<data_flow::AccessNode&>(iedge_b->src());
5✔
256
    auto& output_ptr = static_cast<data_flow::AccessNode&>(iedge_y->src());
5✔
257

258
    if (dataflow.in_degree(input_node_a) != 0 || dataflow.in_degree(input_node_b) != 0 ||
5✔
259
        dataflow.in_degree(output_ptr) != 0) {
5✔
260
        return false;
×
261
    }
×
262

263
    // Determine BLAS precision from primitive type
264
    auto prim_type = this->uniform_quantization(dataflow);
5✔
265
    if (!prim_type) {
5✔
266
        return false;
×
267
    }
×
268
    blas::BLAS_Precision precision;
5✔
269
    switch (prim_type.value()) {
5✔
270
        case types::PrimitiveType::Half:
×
271
            precision = blas::BLAS_Precision::h;
×
272
            break;
×
273
        case types::PrimitiveType::Float:
3✔
274
            precision = blas::BLAS_Precision::s;
3✔
275
            break;
3✔
276
        case types::PrimitiveType::Double:
1✔
277
            precision = blas::BLAS_Precision::d;
1✔
278
            break;
1✔
279
        default:
1✔
280
            // GEMM only supports floating point types, fall back to naive expansion
281
            return false;
1✔
282
    };
5✔
283

284
    // Add new graph after the current block
285
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
4✔
286

287
    auto copy_name_a = input_node_a.data();
4✔
288
    auto copy_name_b = input_node_b.data();
4✔
289

290
    // Check if A and B have basic strides and whether they are transposed in the last dimension
291
    blas::BLAS_Transpose trans_a, trans_b;
4✔
292
    if (layout_a_.has_linear_accesses_no_padding()) {
4✔
293
        trans_a = blas::BLAS_Transpose::No;
4✔
294
    } else if (layout_a_.has_transposed_strides_no_padding()) {
4✔
295
        trans_a = blas::BLAS_Transpose::Trans;
×
296
    } else {
×
297
        trans_a = blas::BLAS_Transpose::No;
×
298
        throw InvalidSDFGException("A must be in c-order");
×
299
    }
×
300
    if (layout_b_.has_linear_accesses_no_padding()) {
4✔
301
        trans_b = blas::BLAS_Transpose::No;
4✔
302
    } else if (layout_b_.has_transposed_strides_no_padding()) {
4✔
303
        trans_b = blas::BLAS_Transpose::Trans;
×
304
    } else {
×
305
        trans_b = blas::BLAS_Transpose::No;
×
306
        throw InvalidSDFGException("B must be in c-order");
×
307
    }
×
308

309
    // Create maps for batch dimensions and M, N dimensions
310
    structured_control_flow::Sequence* last_scope = &new_sequence;
4✔
311
    structured_control_flow::Map* last_map = nullptr;
4✔
312
    symbolic::MultiExpression batch_vars;
4✔
313

314
    // Compute batch dimensions (all except last 2)
315
    size_t batch_dims_a = layout_a_.dims() - 2;
4✔
316
    size_t batch_dims_b = layout_b_.dims() - 2;
4✔
317
    size_t max_batch_dims = std::max(batch_dims_a, batch_dims_b);
4✔
318

319
    // Create maps for batch dimensions (using broadcasting)
320
    for (size_t i = 0; i < max_batch_dims; ++i) {
5✔
321
        std::string indvar_str = builder.find_new_name("_b");
1✔
322
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
1✔
323

324
        auto indvar = symbolic::symbol(indvar_str);
1✔
325
        auto init = symbolic::zero();
1✔
326
        auto update = symbolic::add(indvar, symbolic::one());
1✔
327

328
        // Determine the bound for this batch dimension (max of A and B for broadcasting)
329
        symbolic::Expression bound;
1✔
330
        size_t a_idx = batch_dims_a >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_a) : SIZE_MAX;
1✔
331
        size_t b_idx = batch_dims_b >= (max_batch_dims - i) ? i - (max_batch_dims - batch_dims_b) : SIZE_MAX;
1✔
332

333
        if (a_idx != SIZE_MAX && b_idx != SIZE_MAX) {
1✔
334
            // Both have this dimension - they should be equal or one should be 1 (broadcasting)
335
            bound = layout_a_.get_dim(a_idx); // Assume they match or broadcasting is handled
1✔
336
        } else if (a_idx != SIZE_MAX) {
1✔
337
            bound = layout_a_.get_dim(a_idx);
×
338
        } else {
×
339
            bound = layout_b_.get_dim(b_idx);
×
340
        }
×
341

342
        auto condition = symbolic::Lt(indvar, bound);
1✔
343
        last_map = &builder.add_map(
1✔
344
            *last_scope,
1✔
345
            indvar,
1✔
346
            condition,
1✔
347
            init,
1✔
348
            update,
1✔
349
            structured_control_flow::ScheduleType_Sequential::create(),
1✔
350
            {},
1✔
351
            block.debug_info()
1✔
352
        );
1✔
353
        last_scope = &last_map->root();
1✔
354
        batch_vars.push_back(indvar);
1✔
355
    }
1✔
356

357
    auto& ref_block = builder.add_block(*last_scope, {}, block.debug_info());
4✔
358

359
    auto scalar_type = types::Scalar(prim_type.value());
4✔
360

361
    // Compute offsets for this batch iteration
362
    // For A: base_offset_a = offset_a + sum_i(batch_idx_i * batch_stride_a_i)
363
    symbolic::Expression a_batch_offset = layout_a_.offset();
4✔
364
    for (size_t i = 0; i < batch_dims_a; ++i) {
5✔
365
        size_t batch_idx = max_batch_dims - batch_dims_a + i;
1✔
366
        a_batch_offset = symbolic::add(a_batch_offset, symbolic::mul(batch_vars[batch_idx], layout_a_.get_stride(i)));
1✔
367
    }
1✔
368

369
    // For B: base_offset_b = offset_b + sum_i(batch_idx_i * batch_stride_b_i)
370
    symbolic::Expression b_batch_offset = layout_b_.offset();
4✔
371
    for (size_t i = 0; i < batch_dims_b; ++i) {
5✔
372
        size_t batch_idx = max_batch_dims - batch_dims_b + i;
1✔
373
        b_batch_offset = symbolic::add(b_batch_offset, symbolic::mul(batch_vars[batch_idx], layout_b_.get_stride(i)));
1✔
374
    }
1✔
375

376
    // Compute output batch offset (same as batch_vars pattern for Y)
377
    symbolic::Expression c_batch_offset = symbolic::integer(0);
4✔
378
    for (size_t i = 0; i < batch_vars.size(); ++i) {
5✔
379
        // Output has shape [batch..., M, N] with row-major strides
380
        // Stride for batch dim i is: M * N * product of remaining batch dims
381
        symbolic::Expression c_stride = symbolic::mul(this->m(), this->n());
1✔
382
        for (size_t j = i + 1; j < batch_vars.size(); ++j) {
1✔
383
            // Multiply by subsequent batch dimensions
384
            if (j < batch_dims_a) {
×
385
                c_stride = symbolic::mul(c_stride, layout_a_.get_dim(j));
×
386
            } else if (j - batch_dims_a < batch_dims_b) {
×
387
                c_stride = symbolic::mul(c_stride, layout_b_.get_dim(j - batch_dims_a));
×
388
            }
×
389
        }
×
390
        c_batch_offset = symbolic::add(c_batch_offset, symbolic::mul(batch_vars[i], c_stride));
1✔
391
    }
1✔
392

393
    // Create access nodes
394
    auto& a_access = builder.add_access(ref_block, copy_name_a, debug_info());
4✔
395
    auto& b_access = builder.add_access(ref_block, copy_name_b, debug_info());
4✔
396
    auto& c_access_in = builder.add_access(ref_block, output_ptr.data(), debug_info());
4✔
397

398
    std::string ref_name_a = builder.find_new_name(copy_name_a + "_ref");
4✔
399
    builder.add_container(ref_name_a, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
400
    auto& a_access_ref = builder.add_access(ref_block, ref_name_a, debug_info());
4✔
401
    std::string ref_name_b = builder.find_new_name(copy_name_b + "_ref");
4✔
402
    builder.add_container(ref_name_b, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
403
    auto& b_access_ref = builder.add_access(ref_block, ref_name_b, debug_info());
4✔
404
    std::string ref_name_c = builder.find_new_name(output_ptr.data() + "_ref");
4✔
405
    builder.add_container(ref_name_c, types::Pointer(types::Scalar(types::PrimitiveType::Void)));
4✔
406
    auto& c_access_ref_in = builder.add_access(ref_block, ref_name_c, debug_info());
4✔
407

408
    builder.add_reference_memlet(
4✔
409
        ref_block, a_access, a_access_ref, {a_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
410
    );
4✔
411
    builder.add_reference_memlet(
4✔
412
        ref_block, b_access, b_access_ref, {b_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
413
    );
4✔
414
    builder.add_reference_memlet(
4✔
415
        ref_block, c_access_in, c_access_ref_in, {c_batch_offset}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
416
    );
4✔
417

418
    // Create block with GEMM library node
419
    auto& gemm_block = builder.add_block(*last_scope, {}, block.debug_info());
4✔
420

421
    // Leading dimensions: stride of the row dimension (second-to-last dim)
422
    symbolic::Expression lda, ldb;
4✔
423
    if (trans_a == blas::BLAS_Transpose::No) {
4✔
424
        // For row-major A [m * k] -> lda = k
425
        lda = layout_a_.get_stride_innermost(1);
4✔
426
    } else {
4✔
427
        // For row-major A [m * k] -> lda = m
428
        lda = layout_a_.get_stride_innermost(0);
×
429
    }
×
430
    if (trans_b == blas::BLAS_Transpose::No) {
4✔
431
        // For row-major B [k * n] -> ldb = n
432
        ldb = layout_b_.get_stride_innermost(1);
4✔
433
    } else {
4✔
434
        // For row-major B [k * n] -> ldb = k
435
        ldb = layout_b_.get_stride_innermost(0);
×
436
    }
×
437
    // For row-major C [m * n] -> ldc = n
438
    auto ldc = this->n();
4✔
439

440
    // Add GEMM node: C = alpha * A * B + beta * C
441
    // With alpha = 1.0, beta = 0.0: C = A * B
442
    auto& gemm_node = builder.add_library_node<blas::GEMMNode>(
4✔
443
        gemm_block,
4✔
444
        debug_info(),
4✔
445
        blas::ImplementationType_BLAS,
4✔
446
        precision,
4✔
447
        blas::BLAS_Layout::RowMajor,
4✔
448
        trans_a,
4✔
449
        trans_b,
4✔
450
        this->m(),
4✔
451
        this->n(),
4✔
452
        this->k(),
4✔
453
        lda,
4✔
454
        ldb,
4✔
455
        ldc
4✔
456
    );
4✔
457

458
    auto& a_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_a, debug_info());
4✔
459
    auto& b_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_b, debug_info());
4✔
460
    auto& c_access_ref_in_gemm = builder.add_access(gemm_block, ref_name_c, debug_info());
4✔
461

462
    // Create alpha and beta constants
463
    auto& alpha_const = builder.add_constant(gemm_block, "1.0", scalar_type, debug_info());
4✔
464
    auto& beta_const = builder.add_constant(gemm_block, "0.0", scalar_type, debug_info());
4✔
465

466
    // Connect memlets with batch offsets
467
    // Input A with offset
468
    builder.add_computational_memlet(
4✔
469
        gemm_block, a_access_ref_in_gemm, gemm_node, "__A", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
470
    );
4✔
471
    // Input B with offset
472
    builder.add_computational_memlet(
4✔
473
        gemm_block, b_access_ref_in_gemm, gemm_node, "__B", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
474
    );
4✔
475
    // Input C (for beta * C, but beta=0 so just needs to be connected)
476
    builder.add_computational_memlet(
4✔
477
        gemm_block, c_access_ref_in_gemm, gemm_node, "__C", {}, ::sdfg::types::Pointer(scalar_type), debug_info()
4✔
478
    );
4✔
479
    // Alpha constant
480
    builder.add_computational_memlet(gemm_block, alpha_const, gemm_node, "__alpha", {}, scalar_type, debug_info());
4✔
481
    // Beta constant
482
    builder.add_computational_memlet(gemm_block, beta_const, gemm_node, "__beta", {}, scalar_type, debug_info());
4✔
483

484
    // Free copies if we made them
485
    if (copy_name_a != input_node_a.data()) {
4✔
486
        free_after_copy(copy_name_a, builder, new_sequence);
×
487
    }
×
488
    if (copy_name_b != input_node_b.data()) {
4✔
489
        free_after_copy(copy_name_b, builder, new_sequence);
×
490
    }
×
491

492

493
    builder.clear_code_node_legacy(block, *this);
4✔
494
    // WARNING: this has been deallocated at this point!!
495
    builder.remove_child(parent, index + 1);
4✔
496

497
    return true;
4✔
498
}
4✔
499

500
/**
 * Serializes the MatMul-specific fields: "code", "layout_a", "layout_b" and "result_quant"
 * (the fixed quantization).
 */
nlohmann::json MatMulNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const MatMulNode& matmul_node = static_cast<const MatMulNode&>(library_node);
    nlohmann::json j;

    j["code"] = matmul_node.code().value();

    matmul_node.layout_a().serialize_to_json(j["layout_a"]);
    matmul_node.layout_b().serialize_to_json(j["layout_b"]);

    j["result_quant"] = matmul_node.fixed_quantization();

    return j;
}
×
515

516
/**
 * Recreates a MatMulNode in `parent` from JSON.
 *
 * Layouts are read either from "layout_a"/"layout_b" or, if "layout_a" is absent, from the
 * separate fields "shape_a"/"shape_b" plus optional "strides_a"/"strides_b" and
 * "offset_a"/"offset_b" (offset defaults to 0). "result_quant" defaults to
 * QUANTIZATION_MATCH_INPUTS.
 *
 * Required keys are read with .at(), so a missing key throws nlohmann::json::out_of_range
 * even when the asserts below are compiled out.
 */
data_flow::LibraryNode& MatMulNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // Parses a JSON array of expression strings into a MultiExpression.
    auto parse_expressions = [](const nlohmann::json& array) {
        symbolic::MultiExpression exprs;
        for (const auto& entry : array) {
            exprs.push_back(symbolic::parse(entry.get<std::string>()));
        }
        return exprs;
    };

    std::optional<TensorLayout> layout_a;
    std::optional<TensorLayout> layout_b;

    auto layout_a_it = j.find("layout_a");
    if (layout_a_it != j.end()) {
        layout_a = TensorLayout::deserialize_from_json(*layout_a_it);
        layout_b = TensorLayout::deserialize_from_json(j.at("layout_b"));
    } else {
        // Per-field format (presumably from older serializations)
        assert(j.contains("shape_a"));
        assert(j.contains("shape_b"));

        // const operator[] on a missing key is undefined behaviour; .at() throws instead.
        symbolic::MultiExpression shape_a = parse_expressions(j.at("shape_a"));
        symbolic::MultiExpression shape_b = parse_expressions(j.at("shape_b"));

        symbolic::MultiExpression strides_a;
        if (j.contains("strides_a")) {
            strides_a = parse_expressions(j.at("strides_a"));
        }

        symbolic::MultiExpression strides_b;
        if (j.contains("strides_b")) {
            strides_b = parse_expressions(j.at("strides_b"));
        }

        symbolic::Expression offset_a = symbolic::integer(0);
        if (j.contains("offset_a")) {
            offset_a = symbolic::parse(j.at("offset_a").get<std::string>());
        }

        symbolic::Expression offset_b = symbolic::integer(0);
        if (j.contains("offset_b")) {
            offset_b = symbolic::parse(j.at("offset_b").get<std::string>());
        }

        layout_a = TensorLayout(shape_a, strides_a, offset_a);
        layout_b = TensorLayout(shape_b, strides_b, offset_b);
    }

    auto quantization = deserialize_quantization(j, "result_quant", QUANTIZATION_MATCH_INPUTS);

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<MatMulNode>(parent, debug_info, layout_a.value(), layout_b.value(), quantization);
}
×
580

581
} // namespace tensor
582
} // namespace math
583
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc