• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27981272983

22 Jun 2026 08:18PM UTC coverage: 61.754% (-0.03%) from 61.782%
27981272983

Pull #781

github

web-flow
Merge bddaa3724 into fe87d162b
Pull Request #781: Extend Segformer benchmarks setup

987 of 1432 new or added lines in 62 files covered. (68.92%)

9 existing lines in 7 files now uncovered.

38121 of 61730 relevant lines covered (61.75%)

993.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

45.75
/sdfg/src/data_flow/library_nodes/math/blas/gemm_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
namespace sdfg {
7
namespace math {
8
namespace blas {
9

10
GEMMNode::GEMMNode(
11
    size_t element_id,
12
    const DebugInfo& debug_info,
13
    const graph::Vertex vertex,
14
    data_flow::DataFlowGraph& parent,
15
    const data_flow::ImplementationType& implementation_type,
16
    const BLAS_Precision& precision,
17
    const BLAS_Layout& layout,
18
    const BLAS_Transpose& trans_a,
19
    const BLAS_Transpose& trans_b,
20
    symbolic::Expression m,
21
    symbolic::Expression n,
22
    symbolic::Expression k,
23
    symbolic::Expression lda,
24
    symbolic::Expression ldb,
25
    symbolic::Expression ldc
26
)
27
    : BLASNode(
35✔
28
          element_id,
35✔
29
          debug_info,
35✔
30
          vertex,
35✔
31
          parent,
35✔
32
          LibraryNodeType_GEMM,
35✔
33
          {},
35✔
34
          {"__A", "__B", "__C", "__alpha", "__beta"},
35✔
35
          implementation_type,
35✔
36
          precision
35✔
37
      ),
35✔
38
      layout_(layout), trans_a_(trans_a), trans_b_(trans_b), m_(m), n_(n), k_(k), lda_(lda), ldb_(ldb), ldc_(ldc) {}
35✔
39

40
BLAS_Layout GEMMNode::layout() const { return this->layout_; };
4✔
41

42
BLAS_Transpose GEMMNode::trans_a() const { return this->trans_a_; };
5✔
43

44
BLAS_Transpose GEMMNode::trans_b() const { return this->trans_b_; };
5✔
45

46
symbolic::Expression GEMMNode::m() const { return this->m_; };
19✔
47

48
symbolic::Expression GEMMNode::n() const { return this->n_; };
21✔
49

50
symbolic::Expression GEMMNode::k() const { return this->k_; };
19✔
51

52
symbolic::Expression GEMMNode::lda() const { return this->lda_; };
11✔
53

54
symbolic::Expression GEMMNode::ldb() const { return this->ldb_; };
11✔
55

56
symbolic::Expression GEMMNode::ldc() const { return this->ldc_; };
11✔
57

58
// Returns every atom (as reported by symbolic::atoms) occurring in the symbolic sizes m, n, k
// and the leading dimensions lda, ldb, ldc, merged into one set.
symbolic::SymbolSet GEMMNode::symbols() const {
    symbolic::SymbolSet result;
    for (const symbolic::Expression* dim : {&m_, &n_, &k_, &lda_, &ldb_, &ldc_}) {
        for (auto& atom : symbolic::atoms(*dim)) {
            result.insert(atom);
        }
    }
    return result;
}
82

83
// Substitutes old_expression with new_expression in each symbolic size (m, n, k) and
// leading dimension (lda, ldb, ldc), in that order.
void GEMMNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    for (symbolic::Expression* field : {&m_, &n_, &k_, &lda_, &ldb_, &ldc_}) {
        *field = symbolic::subs(*field, old_expression, new_expression);
    }
}
91

NEW
92
// Applies every substitution in `replacements` to each symbolic size (m, n, k) and
// leading dimension (lda, ldb, ldc), in that order.
void GEMMNode::replace(const symbolic::ExpressionMapping& replacements) {
    for (symbolic::Expression* field : {&m_, &n_, &k_, &lda_, &ldb_, &ldc_}) {
        *field = symbolic::subs(*field, replacements);
    }
}
100

101
// GEMM performs no validation of its own; everything is delegated to BLASNode.
void GEMMNode::validate(const Function& function) const {
    BLASNode::validate(function);
}
102

103
bool GEMMNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
8✔
104
    auto& dataflow = this->get_parent();
8✔
105
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
8✔
106
    auto& parent = static_cast<structured_control_flow::Sequence&>(*block.get_parent());
8✔
107
    int index = parent.index(block);
8✔
108
    auto& transition = parent.at(index).second;
8✔
109

110
    if (trans_a_ == BLAS_Transpose::ConjTrans || trans_b_ == BLAS_Transpose::ConjTrans) {
8✔
111
        return false;
×
112
    }
×
113

114
    auto primitive_type = scalar_primitive();
8✔
115
    if (primitive_type == types::PrimitiveType::Void) {
8✔
116
        return false;
×
117
    }
×
118

119
    types::Scalar scalar_type(primitive_type);
8✔
120

121
    auto in_edges = dataflow.in_edges(*this);
8✔
122
    auto in_edges_it = in_edges.begin();
8✔
123

124
    data_flow::Memlet* iedge_a = nullptr;
8✔
125
    data_flow::Memlet* iedge_b = nullptr;
8✔
126
    data_flow::Memlet* iedge_c = nullptr;
8✔
127
    data_flow::Memlet* alpha_edge = nullptr;
8✔
128
    data_flow::Memlet* beta_edge = nullptr;
8✔
129
    while (in_edges_it != in_edges.end()) {
48✔
130
        auto& edge = *in_edges_it;
40✔
131
        auto dst_conn = edge.dst_conn();
40✔
132
        if (dst_conn == "__A") {
40✔
133
            iedge_a = &edge;
8✔
134
        } else if (dst_conn == "__B") {
32✔
135
            iedge_b = &edge;
8✔
136
        } else if (dst_conn == "__C") {
24✔
137
            iedge_c = &edge;
8✔
138
        } else if (dst_conn == "__alpha") {
16✔
139
            alpha_edge = &edge;
8✔
140
        } else if (dst_conn == "__beta") {
8✔
141
            beta_edge = &edge;
8✔
142
        } else {
8✔
143
            throw InvalidSDFGException("GEMMNode has unexpected input: " + dst_conn);
×
144
        }
×
145
        ++in_edges_it;
40✔
146
    }
40✔
147

148
    // Checks if legal
149
    auto* input_node_a = static_cast<data_flow::AccessNode*>(&iedge_a->src());
8✔
150
    auto* input_node_b = static_cast<data_flow::AccessNode*>(&iedge_b->src());
8✔
151
    auto* input_node_c = static_cast<data_flow::AccessNode*>(&iedge_c->src());
8✔
152
    auto* alpha_node = static_cast<data_flow::AccessNode*>(&alpha_edge->src());
8✔
153
    auto* beta_node = static_cast<data_flow::AccessNode*>(&beta_edge->src());
8✔
154

155
    // we must be the only thing in this block, as we do not support splitting a block into pre, expanded lib-node, post
156
    if (!input_node_a || dataflow.in_degree(*input_node_a) != 0 || !input_node_b ||
8✔
157
        dataflow.in_degree(*input_node_b) != 0 || !input_node_c || dataflow.in_degree(*input_node_c) != 0) {
8✔
158
        return false; // data nodes are not standalone
×
159
    }
×
160
    if (dataflow.in_degree(*alpha_node) != 0 || dataflow.in_degree(*beta_node) != 0) {
8✔
161
        return false; // alpha and beta are not standalone
×
162
    }
×
163
    for (auto* nd : dataflow.data_nodes()) {
40✔
164
        if (nd != input_node_a && nd != input_node_b && nd != input_node_c && (!alpha_node || nd != alpha_node) &&
40✔
165
            (!beta_node || nd != beta_node)) {
40✔
166
            return false; // there are other nodes in here that we could not preserve correctly
×
167
        }
×
168
    }
40✔
169

170
    auto& A_var = input_node_a->data();
8✔
171
    auto& B_var = input_node_b->data();
8✔
172
    auto& C_ptr = input_node_c->data();
8✔
173

174

175
    // Add new graph after the current block
176
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
8✔
177

178
    // Add maps
179
    std::vector<symbolic::Expression> indvar_ends{this->m(), this->n(), this->k()};
8✔
180
    data_flow::Subset new_subset;
8✔
181
    structured_control_flow::Sequence* last_scope = &new_sequence;
8✔
182
    structured_control_flow::StructuredLoop* last_map = nullptr;
8✔
183
    structured_control_flow::StructuredLoop* output_loop = nullptr;
8✔
184
    std::vector<std::string> indvar_names{"_i", "_j", "_k"};
8✔
185

186
    std::string sum_var = builder.find_new_name("_sum");
8✔
187
    builder.add_container(sum_var, scalar_type);
8✔
188

189
    for (size_t i = 0; i < 3; i++) {
32✔
190
        auto dim_begin = symbolic::zero();
24✔
191
        auto& dim_end = indvar_ends[i];
24✔
192

193
        std::string indvar_str = builder.find_new_name(indvar_names[i]);
24✔
194
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
24✔
195

196
        auto indvar = symbolic::symbol(indvar_str);
24✔
197
        auto init = dim_begin;
24✔
198
        auto update = symbolic::add(indvar, symbolic::one());
24✔
199
        auto condition = symbolic::Lt(indvar, dim_end);
24✔
200
        if (i < 2) {
24✔
201
            last_map = &builder.add_map(
16✔
202
                *last_scope,
16✔
203
                indvar,
16✔
204
                condition,
16✔
205
                init,
16✔
206
                update,
16✔
207
                structured_control_flow::ScheduleType_Sequential::create(),
16✔
208
                {},
16✔
209
                block.debug_info()
16✔
210
            );
16✔
211
        } else {
16✔
212
            last_map = &builder.add_for(*last_scope, indvar, condition, init, update, {}, block.debug_info());
8✔
213
        }
8✔
214
        last_scope = &last_map->root();
24✔
215

216
        if (i == 1) {
24✔
217
            output_loop = last_map;
8✔
218
        }
8✔
219

220
        new_subset.push_back(indvar);
24✔
221
    }
24✔
222

223

224
    // Add code
225
    auto& init_block = builder.add_block_before(output_loop->root(), *last_map, {}, block.debug_info());
8✔
226
    auto& sum_init = builder.add_access(init_block, sum_var, block.debug_info());
8✔
227

228
    auto& zero_node = builder.add_constant(init_block, "0.0", alpha_edge->base_type(), block.debug_info());
8✔
229
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
8✔
230
    builder.add_computational_memlet(init_block, zero_node, init_tasklet, "_in", {}, block.debug_info());
8✔
231
    builder.add_computational_memlet(init_block, init_tasklet, "_out", sum_init, {}, block.debug_info());
8✔
232

233
    auto& code_block = builder.add_block(*last_scope, {}, block.debug_info());
8✔
234
    auto& input_node_a_new = builder.add_access(code_block, A_var, input_node_a->debug_info());
8✔
235
    auto& input_node_b_new = builder.add_access(code_block, B_var, input_node_b->debug_info());
8✔
236

237
    auto& core_fma =
8✔
238
        builder.add_tasklet(code_block, data_flow::fp_fma, "_out", {"_in1", "_in2", "_in3"}, block.debug_info());
8✔
239
    auto& sum_in = builder.add_access(code_block, sum_var, block.debug_info());
8✔
240
    auto& sum_out = builder.add_access(code_block, sum_var, block.debug_info());
8✔
241
    builder.add_computational_memlet(code_block, sum_in, core_fma, "_in3", {}, block.debug_info());
8✔
242

243
    // Row-major indexing: address = ld * row + col
244
    // No transpose: A is m×k, access A[i, k] => lda*i + k
245
    // Transpose:    A is k×m stored, access A[k, i] => lda*k + i
246
    symbolic::Expression a_idx = (trans_a_ == BLAS_Transpose::Trans)
8✔
247
                                     ? symbolic::add(symbolic::mul(lda(), new_subset[2]), new_subset[0])
8✔
248
                                     : symbolic::add(symbolic::mul(lda(), new_subset[0]), new_subset[2]);
8✔
249
    builder.add_computational_memlet(
8✔
250
        code_block, input_node_a_new, core_fma, "_in1", {a_idx}, iedge_a->base_type(), iedge_a->debug_info()
8✔
251
    );
8✔
252
    // No transpose: B is k×n, access B[k, j] => ldb*k + j
253
    // Transpose:    B is n×k stored, access B[j, k] => ldb*j + k
254
    symbolic::Expression b_idx = (trans_b_ == BLAS_Transpose::Trans)
8✔
255
                                     ? symbolic::add(symbolic::mul(ldb(), new_subset[1]), new_subset[2])
8✔
256
                                     : symbolic::add(symbolic::mul(ldb(), new_subset[2]), new_subset[1]);
8✔
257
    builder.add_computational_memlet(
8✔
258
        code_block, input_node_b_new, core_fma, "_in2", {b_idx}, iedge_b->base_type(), iedge_b->debug_info()
8✔
259
    );
8✔
260
    builder.add_computational_memlet(code_block, core_fma, "_out", sum_out, {}, iedge_c->debug_info());
8✔
261

262
    auto& flush_block = builder.add_block_after(output_loop->root(), *last_map, {}, block.debug_info());
8✔
263
    auto& sum_final = builder.add_access(flush_block, sum_var, block.debug_info());
8✔
264
    auto& input_node_c_new = builder.add_access(flush_block, C_ptr, input_node_c->debug_info());
8✔
265
    symbolic::Expression c_idx = symbolic::add(symbolic::mul(ldc(), new_subset[0]), new_subset[1]);
8✔
266

267
    auto& scale_sum_tasklet =
8✔
268
        builder.add_tasklet(flush_block, data_flow::TaskletCode::fp_mul, "_out", {"_in1", "_in2"}, block.debug_info());
8✔
269
    builder.add_computational_memlet(flush_block, sum_final, scale_sum_tasklet, "_in1", {}, block.debug_info());
8✔
270
    if (auto const_node = dynamic_cast<data_flow::ConstantNode*>(alpha_node)) {
8✔
271
        auto& alpha_node_new =
8✔
272
            builder.add_constant(flush_block, const_node->data(), const_node->type(), block.debug_info());
8✔
273
        builder.add_computational_memlet(flush_block, alpha_node_new, scale_sum_tasklet, "_in2", {}, block.debug_info());
8✔
274
    } else {
8✔
275
        auto& alpha_node_new = builder.add_access(flush_block, alpha_node->data(), block.debug_info());
×
276
        builder.add_computational_memlet(flush_block, alpha_node_new, scale_sum_tasklet, "_in2", {}, block.debug_info());
×
277
    }
×
278

279
    std::string scaled_sum_temp = builder.find_new_name("scaled_sum_temp");
8✔
280
    builder.add_container(scaled_sum_temp, scalar_type);
8✔
281
    auto& scaled_sum_final = builder.add_access(flush_block, scaled_sum_temp, block.debug_info());
8✔
282
    builder.add_computational_memlet(
8✔
283
        flush_block, scale_sum_tasklet, "_out", scaled_sum_final, {}, scalar_type, block.debug_info()
8✔
284
    );
8✔
285

286
    auto& scale_input_tasklet =
8✔
287
        builder.add_tasklet(flush_block, data_flow::TaskletCode::fp_mul, "_out", {"_in1", "_in2"}, block.debug_info());
8✔
288
    builder.add_computational_memlet(
8✔
289
        flush_block, input_node_c_new, scale_input_tasklet, "_in1", {c_idx}, iedge_c->base_type(), iedge_c->debug_info()
8✔
290
    );
8✔
291
    if (auto const_node = dynamic_cast<data_flow::ConstantNode*>(beta_node)) {
8✔
292
        auto& beta_node_new =
8✔
293
            builder.add_constant(flush_block, const_node->data(), const_node->type(), block.debug_info());
8✔
294
        builder
8✔
295
            .add_computational_memlet(flush_block, beta_node_new, scale_input_tasklet, "_in2", {}, block.debug_info());
8✔
296
    } else {
8✔
297
        auto& beta_node_new = builder.add_access(flush_block, beta_node->data(), block.debug_info());
×
298
        builder
×
299
            .add_computational_memlet(flush_block, beta_node_new, scale_input_tasklet, "_in2", {}, block.debug_info());
×
300
    }
×
301

302
    std::string scaled_input_temp = builder.find_new_name("scaled_input_temp");
8✔
303
    builder.add_container(scaled_input_temp, scalar_type);
8✔
304
    auto& scaled_input_c = builder.add_access(flush_block, scaled_input_temp, block.debug_info());
8✔
305
    builder.add_computational_memlet(
8✔
306
        flush_block, scale_input_tasklet, "_out", scaled_input_c, {}, scalar_type, block.debug_info()
8✔
307
    );
8✔
308

309
    auto& flush_add_tasklet =
8✔
310
        builder.add_tasklet(flush_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, block.debug_info());
8✔
311
    auto& output_node_new = builder.add_access(flush_block, C_ptr, input_node_c->debug_info());
8✔
312
    builder.add_computational_memlet(
8✔
313
        flush_block, scaled_sum_final, flush_add_tasklet, "_in1", {}, scalar_type, block.debug_info()
8✔
314
    );
8✔
315
    builder.add_computational_memlet(
8✔
316
        flush_block, scaled_input_c, flush_add_tasklet, "_in2", {}, scalar_type, block.debug_info()
8✔
317
    );
8✔
318
    builder.add_computational_memlet(
8✔
319
        flush_block, flush_add_tasklet, "_out", output_node_new, {c_idx}, iedge_c->base_type(), iedge_c->debug_info()
8✔
320
    );
8✔
321

322

323
    // Clean up block
324
    builder.remove_memlet(block, *iedge_a);
8✔
325
    builder.remove_memlet(block, *iedge_b);
8✔
326
    builder.remove_memlet(block, *iedge_c);
8✔
327
    builder.remove_memlet(block, *alpha_edge);
8✔
328
    builder.remove_node(block, *alpha_node);
8✔
329
    builder.remove_memlet(block, *beta_edge);
8✔
330
    builder.remove_node(block, *beta_node);
8✔
331
    builder.remove_node(block, *input_node_a);
8✔
332
    builder.remove_node(block, *input_node_b);
8✔
333
    builder.remove_node(block, *input_node_c);
8✔
334
    builder.remove_node(block, *this);
8✔
335
    builder.remove_child(parent, index + 1);
8✔
336

337
    return true;
8✔
338
}
8✔
339

340
// Worst-case FLOP count: every alpha/beta condition is taken as true, so all optional
// scaling and accumulation steps are counted.
symbolic::Expression GEMMNode::flop() const {
    auto always = symbolic::__true__();
    return flops(always, always, always, always);
}
343

344
// Symbolic FLOP count of C = alpha * op(A) * op(B) + beta * C.
// Each condition acts as a 0/1 factor that switches a group of operations on or off:
//   alpha_non_zero  - the k-length dot products are computed at all
//   alpha_non_ident - the dot products are additionally scaled by alpha
//   beta_non_zero   - beta * C is added to the result
//   beta_non_ident  - C is additionally scaled by beta
symbolic::Expression GEMMNode::flops(
    symbolic::Condition alpha_non_zero,
    symbolic::Condition alpha_non_ident,
    symbolic::Condition beta_non_zero,
    symbolic::Condition beta_non_ident
) const {
    // One dot product per element of the m x n result.
    auto elements = symbolic::mul(this->m_, this->n_);

    // Dot products: k multiplications and k - 1 additions each (only when alpha != 0).
    auto dot_muls = symbolic::mul(symbolic::mul(elements, this->k_), alpha_non_zero);
    auto dot_adds = symbolic::mul(symbolic::mul(elements, symbolic::sub(this->k_, symbolic::one())), alpha_non_zero);

    // Scaling by alpha (alpha != 0 and alpha != 1) and by beta (beta != 0 and beta != 1).
    auto alpha_muls = symbolic::mul(elements, symbolic::And(alpha_non_ident, alpha_non_zero));
    auto beta_muls = symbolic::mul(elements, symbolic::And(beta_non_ident, beta_non_zero));

    // Adding the (scaled) C term to the result (beta != 0).
    auto beta_adds = symbolic::mul(elements, beta_non_zero);

    auto total_muls = symbolic::add(dot_muls, symbolic::add(alpha_muls, beta_muls));
    auto total_adds = symbolic::add(dot_adds, beta_adds);
    return symbolic::add(total_muls, total_adds);
}
364

365
std::unique_ptr<data_flow::DataFlowNode> GEMMNode::
366
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
367
    auto node_clone = std::unique_ptr<GEMMNode>(new GEMMNode(
×
368
        element_id,
×
369
        this->debug_info(),
×
370
        vertex,
×
371
        parent,
×
372
        this->implementation_type_,
×
373
        this->precision_,
×
374
        this->layout_,
×
375
        this->trans_a_,
×
376
        this->trans_b_,
×
377
        this->m_,
×
378
        this->n_,
×
379
        this->k_,
×
380
        this->lda_,
×
381
        this->ldb_,
×
382
        this->ldc_
×
383
    ));
×
384
    return std::move(node_clone);
×
385
}
×
386

387
std::string GEMMNode::toStr() const {
×
388
    return LibraryNode::toStr() + "(" + static_cast<char>(precision_) + ", " +
×
389
           std::string(BLAS_Layout_to_short_string(layout_)) + ", " + BLAS_Transpose_to_char(trans_a_) +
×
390
           BLAS_Transpose_to_char(trans_b_) + ", " + m_->__str__() + ", " + n_->__str__() + ", " + k_->__str__() +
×
391
           ", " + lda_->__str__() + ", " + ldb_->__str__() + ", " + ldc_->__str__() + ")";
×
392
}
×
393

394
// Extent (in elements) of a matrix operand whose logical shape is outer_dim x inner_dim and whose
// stored lines are line_size elements apart: (number of stored lines) * line_size.
// The stored lines are the outer (row) dimension for a non-transposed row-major operand. Either a
// transpose or a column-major layout (but not both) switches them to the inner dimension.
symbolic::Expression GEMMNode::calc_matrix_access_range(
    const symbolic::Expression& outer_dim,
    const symbolic::Expression& inner_dim,
    const symbolic::Expression& line_size,
    BLAS_Transpose trans,
    BLAS_Layout layout
) {
    const bool lines_are_outer = (trans == BLAS_Transpose::No) != (layout == BLAS_Layout::ColMajor);
    const symbolic::Expression& line_count = lines_are_outer ? outer_dim : inner_dim;
    return symbolic::mul(line_count, line_size);
}
407

408

409
// Describes how the pointer behind input `input_idx` is accessed:
//   0 (A) and 1 (B): read-only over the range computed by calc_matrix_access_range.
//   2 (C): a full overwrite when C is stored densely, otherwise a generic convex pattern.
//   anything else: delegated to LibraryNode.
data_flow::PointerAccessType GEMMNode::pointer_access_type(int input_idx) const {
    if (input_idx == 0) { // A: op(A) is m x k
        return data_flow::PointerAccessMeta::
            create_read_only(calc_matrix_access_range(m_, k_, lda_, trans_a_, layout_), true);
    } else if (input_idx == 1) { // B: op(B) is k x n
        return data_flow::PointerAccessMeta::
            create_read_only(calc_matrix_access_range(k_, n_, ldb_, trans_b_, layout_), true);
    } else if (input_idx == 2) { // C: m x n
        // For beta == 0 there would be no reads of C, but we currently have no mechanism to access
        // const-prop knowledge like that here.
        // NOTE(review): C is read whenever beta != 0, yet the dense case reports it as write-only;
        // confirm that create_full_write_only (with `true`) is intended to cover read-modify-write.

        // C is dense (no gaps between stored lines) when ldc equals the length of one stored line:
        // n elements per row in row-major layout, m elements per column in column-major layout.
        const symbolic::Expression& line_length = (layout_ == BLAS_Layout::ColMajor) ? m_ : n_;
        if (symbolic::eq(ldc_, line_length)) { // non-sparse access over the m x n range
            return data_flow::PointerAccessMeta::
                create_full_write_only(calc_matrix_access_range(m_, n_, ldc_, BLAS_Transpose::No, layout_), true);
        } else {
            // Sparse (padded) access. With only a convex pattern available we cannot represent which
            // values are fully overwritten and which are don't-care (DC), so describe it generically.
            auto pattern =
                data_flow::ConvexAccessPattern::create(calc_matrix_access_range(m_, n_, ldc_, BLAS_Transpose::No, layout_)
                );
            return data_flow::PointerAccessMeta::create_generic(pattern->ref(), std::move(pattern), true);
        }
    } else {
        return LibraryNode::pointer_access_type(input_idx);
    }
}
434

435
// Serializes the GEMM-specific fields (code, precision, layout, transpose flags and the symbolic
// sizes / leading dimensions) of `library_node`, which must be a GEMMNode.
nlohmann::json GEMMNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const GEMMNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    nlohmann::json j;
    j["code"] = node.code().value();
    j["precision"] = node.precision();
    j["layout"] = node.layout();
    j["trans_a"] = node.trans_a();
    j["trans_b"] = node.trans_b();

    // Symbolic sizes and leading dimensions.
    j["m"] = expr_serializer.expression(node.m());
    j["n"] = expr_serializer.expression(node.n());
    j["k"] = expr_serializer.expression(node.k());
    j["lda"] = expr_serializer.expression(node.lda());
    j["ldb"] = expr_serializer.expression(node.ldb());
    j["ldc"] = expr_serializer.expression(node.ldc());

    return j;
}
454

455
// Recreates a GEMMNode in `parent` from its JSON description.
// Besides the fields written by GEMMNodeSerializer::serialize, this also reads "debug_info" and
// "implementation_type" (presumably added by the generic serializer — not written in serialize).
// Throws std::runtime_error if "code" is not the GEMM library node type. A missing field throws
// from nlohmann::json::at rather than invoking undefined behavior.
data_flow::LibraryNode& GEMMNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // `j` is const: operator[] on a missing key is undefined behavior once asserts are compiled
    // out, so use at(), which throws instead.
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_GEMM.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto layout = j.at("layout").get<BLAS_Layout>();
    auto trans_a = j.at("trans_a").get<BLAS_Transpose>();
    auto trans_b = j.at("trans_b").get<BLAS_Transpose>();
    auto m = symbolic::parse(j.at("m"));
    auto n = symbolic::parse(j.at("n"));
    auto k = symbolic::parse(j.at("k"));
    auto lda = symbolic::parse(j.at("lda"));
    auto ldb = symbolic::parse(j.at("ldb"));
    auto ldc = symbolic::parse(j.at("ldc"));

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<
        GEMMNode>(parent, debug_info, implementation_type, precision, layout, trans_a, trans_b, m, n, k, lda, ldb, ldc);
}
488

489
// Code-generation dispatcher that lowers a GEMMNode to a CBLAS call; all state is held by the
// generic LibraryNodeDispatcher base.
GEMMNodeDispatcher_BLAS::GEMMNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {}
496

497
void GEMMNodeDispatcher_BLAS::dispatch_code_with_edges(
498
    codegen::CodegenOutput& out,
499
    std::vector<codegen::DispatchInput>& inputs,
500
    std::vector<codegen::DispatchOutput>& outputs
501
) {
×
502
    auto& gemm_node = static_cast<const GEMMNode&>(this->node_);
×
503

504
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
505
    switch (gemm_node.precision()) {
×
506
        case BLAS_Precision::h:
×
507
            base_type = types::Scalar(types::PrimitiveType::Half);
×
508
            break;
×
509
        case BLAS_Precision::s:
×
510
            base_type = types::Scalar(types::PrimitiveType::Float);
×
511
            break;
×
512
        case BLAS_Precision::d:
×
513
            base_type = types::Scalar(types::PrimitiveType::Double);
×
514
            break;
×
515
        default:
×
516
            throw std::runtime_error("Invalid BLAS_Precision value");
×
517
    }
×
518

519
    out.library_snippet_factory.require_dependency(BLASLibDependency::instance());
×
520

521
    out.stream << "cblas_" << BLAS_Precision_to_string(gemm_node.precision()) << "gemm(";
×
522
    out.stream.changeIndent(+4);
×
523
    out.stream << BLAS_Layout_to_string(gemm_node.layout());
×
524
    out.stream << ", ";
×
525
    out.stream << BLAS_Transpose_to_string(gemm_node.trans_a());
×
526
    out.stream << ", ";
×
527
    out.stream << BLAS_Transpose_to_string(gemm_node.trans_b());
×
528
    out.stream << ", ";
×
529
    out.stream << this->language_extension_.expression(gemm_node.m());
×
530
    out.stream << ", ";
×
531
    out.stream << this->language_extension_.expression(gemm_node.n());
×
532
    out.stream << ", ";
×
533
    out.stream << this->language_extension_.expression(gemm_node.k());
×
534
    out.stream << ", ";
×
535
    out.stream << inputs.at(GEMMNode::ALPHA_INPUT_IDX).expr;
×
536
    out.stream << ", ";
×
537
    out.stream << inputs.at(GEMMNode::A_INPUT_IDX).expr;
×
538
    out.stream << ", ";
×
539
    out.stream << this->language_extension_.expression(gemm_node.lda());
×
540
    out.stream << ", ";
×
541
    out.stream << inputs.at(GEMMNode::B_INPUT_IDX).expr;
×
542
    out.stream << ", ";
×
543
    out.stream << this->language_extension_.expression(gemm_node.ldb());
×
544
    out.stream << ", ";
×
545
    out.stream << inputs.at(GEMMNode::BETA_INPUT_IDX).expr;
×
546
    out.stream << ", ";
×
547
    out.stream << inputs.at(GEMMNode::C_INPUT_IDX).expr;
×
548
    out.stream << ", ";
×
549
    out.stream << this->language_extension_.expression(gemm_node.ldc());
×
550

551
    out.stream.changeIndent(-4);
×
552
    out.stream << ");" << std::endl;
×
553
}
×
554

555
// Adds a GEMMNode to `block` and wires up its five operands:
//   - new access nodes for ptr_a / ptr_b / ptr_c feed "__A" / "__B" / "__C" (typed a_type / b_type / c_type),
//   - the caller-supplied alpha_node / beta_node feed "__alpha" / "__beta" (typed factor_type).
// alpha_node and beta_node are used as memlet sources in `block`, so presumably they must already
// live in that block — confirm with callers.
// Returns the newly created node.
GEMMNode& add_gemm_node(
    builder::StructuredSDFGBuilder& builder,
    Block& block,
    const std::string& ptr_a,
    const std::string& ptr_b,
    const std::string& ptr_c,
    data_flow::AccessNode& alpha_node,
    data_flow::AccessNode& beta_node,
    const BLAS_Precision& precision,
    const BLAS_Layout& layout,
    const BLAS_Transpose& trans_a,
    const BLAS_Transpose& trans_b,
    symbolic::Expression& m,
    symbolic::Expression& n,
    symbolic::Expression& k,
    symbolic::Expression& lda,
    symbolic::Expression& ldb,
    symbolic::Expression& ldc,
    const types::IType& a_type,
    const types::IType& b_type,
    const types::IType& c_type,
    const types::IType& factor_type,
    DebugInfo debug_info,
    DebugInfo a_access_deb_info,
    DebugInfo b_access_deb_info,
    DebugInfo c_access_deb_info,
    DebugInfo a_edge_deb_info,
    DebugInfo b_edge_deb_info,
    DebugInfo c_edge_deb_info,
    data_flow::ImplementationType impl_type
) {
    GEMMNode& node = static_cast<GEMMNode&>(builder.add_library_node<sdfg::math::blas::GEMMNode>(
        block, debug_info, std::move(impl_type), precision, layout, trans_a, trans_b, m, n, k, lda, ldb, ldc
    ));

    // Matrix operands are read through fresh access nodes.
    auto& a_access = builder.add_access(block, ptr_a, a_access_deb_info);
    auto& b_access = builder.add_access(block, ptr_b, b_access_deb_info);
    auto& c_access = builder.add_access(block, ptr_c, c_access_deb_info);

    // Connect every operand to its input connector.
    builder.add_computational_memlet(block, a_access, node, "__A", {}, a_type, a_edge_deb_info);
    builder.add_computational_memlet(block, b_access, node, "__B", {}, b_type, b_edge_deb_info);
    builder.add_computational_memlet(block, c_access, node, "__C", {}, c_type, c_edge_deb_info);
    builder.add_computational_memlet(block, alpha_node, node, "__alpha", {}, factor_type, debug_info);
    builder.add_computational_memlet(block, beta_node, node, "__beta", {}, factor_type, debug_info);

    return node;
}
604

605
// Convenience overload: like the fully parameterized add_gemm_node, but uses one type
// (ptr_type) for all three matrix memlets and one DebugInfo for the node, every access node
// and every edge.
GEMMNode& add_gemm_node(
    builder::StructuredSDFGBuilder& builder,
    Block& block,
    const std::string& ptr_a,
    const std::string& ptr_b,
    const std::string& ptr_c,
    data_flow::AccessNode& alpha_node,
    data_flow::AccessNode& beta_node,
    const BLAS_Precision& precision,
    const BLAS_Layout& layout,
    const BLAS_Transpose& trans_a,
    const BLAS_Transpose& trans_b,
    symbolic::Expression& m,
    symbolic::Expression& n,
    symbolic::Expression& k,
    symbolic::Expression& lda,
    symbolic::Expression& ldb,
    symbolic::Expression& ldc,
    const types::IType& ptr_type,
    const types::IType& factor_type,
    DebugInfo debug_info,
    data_flow::ImplementationType impl_type
) {
    return add_gemm_node(
        builder, block,
        ptr_a, ptr_b, ptr_c,
        alpha_node, beta_node,
        precision, layout, trans_a, trans_b,
        m, n, k,
        lda, ldb, ldc,
        /* a_type, b_type, c_type */ ptr_type, ptr_type, ptr_type,
        factor_type,
        /* node */ debug_info,
        /* access nodes A, B, C */ debug_info, debug_info, debug_info,
        /* edges A, B, C */ debug_info, debug_info, debug_info,
        impl_type
    );
}
660

661
} // namespace blas
662
} // namespace math
663
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc