• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 17637380013

11 Sep 2025 07:29AM UTC coverage: 59.755% (+0.6%) from 59.145%
17637380013

push

github

web-flow
New debug info (#210)

* initial draft

* update data structure and construction logic

* finalize DebugInfo draft

* fix tests

* Update serializer and fix tests

* fix append bug

* update data structure

* sdfg builder update

* const ref vectors

* update implementation and partial tests

* compiling state

* update serializer interface

* update dot test

* reset interface to debug_info in json to maintain compatibility with tools

* first review batch

* second batch of changes

* merge fixes

777 of 1111 new or added lines in 46 files covered. (69.94%)

11 existing lines in 11 files now uncovered.

9755 of 16325 relevant lines covered (59.75%)

115.06 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

53.56
/src/data_flow/library_nodes/math/blas/gemm.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/gemm.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace blas {
11

12
GEMMNode::GEMMNode(
1✔
13
    size_t element_id,
14
    const DebugInfoRegion& debug_info,
15
    const graph::Vertex vertex,
16
    data_flow::DataFlowGraph& parent,
17
    const data_flow::ImplementationType& implementation_type,
18
    const BLAS_Precision& precision,
19
    const BLAS_Layout& layout,
20
    const BLAS_Transpose& trans_a,
21
    const BLAS_Transpose& trans_b,
22
    symbolic::Expression m,
23
    symbolic::Expression n,
24
    symbolic::Expression k,
25
    symbolic::Expression lda,
26
    symbolic::Expression ldb,
27
    symbolic::Expression ldc,
28
    const std::string& alpha,
29
    const std::string& beta
30
)
31
    : MathNode(
1✔
32
          element_id,
1✔
33
          debug_info,
1✔
34
          vertex,
1✔
35
          parent,
1✔
36
          LibraryNodeType_GEMM,
37
          {"C"},
1✔
38
          {"A", "B", "C", alpha, beta},
1✔
39
          implementation_type
1✔
40
      ),
41
      precision_(precision), layout_(layout), trans_a_(trans_a), trans_b_(trans_b), m_(m), n_(n), k_(k), lda_(lda),
1✔
42
      ldb_(ldb), ldc_(ldc) {}
2✔
43

44
BLAS_Precision GEMMNode::precision() const { return this->precision_; };
×
45

46
BLAS_Layout GEMMNode::layout() const { return this->layout_; };
×
47

48
BLAS_Transpose GEMMNode::trans_a() const { return this->trans_a_; };
×
49

50
BLAS_Transpose GEMMNode::trans_b() const { return this->trans_b_; };
×
51

52
symbolic::Expression GEMMNode::m() const { return this->m_; };
1✔
53

54
symbolic::Expression GEMMNode::n() const { return this->n_; };
1✔
55

56
symbolic::Expression GEMMNode::k() const { return this->k_; };
1✔
57

58
symbolic::Expression GEMMNode::lda() const { return this->lda_; };
1✔
59

60
symbolic::Expression GEMMNode::ldb() const { return this->ldb_; };
1✔
61

62
symbolic::Expression GEMMNode::ldc() const { return this->ldc_; };
1✔
63

64
const std::string& GEMMNode::alpha() const { return this->inputs_.at(3); };
1✔
65

66
const std::string& GEMMNode::beta() const { return this->inputs_.at(4); };
1✔
67

68
void GEMMNode::validate(const Function& function) const {
×
69
    auto& graph = this->get_parent();
×
70

71
    if (this->inputs_.size() != 5) {
×
72
        throw InvalidSDFGException("GEMMNode must have 5 inputs: A, B, C, (alpha), (beta)");
×
73
    }
74

75
    int input_edge_count = graph.in_degree(*this);
×
76
    if (input_edge_count < 3 || input_edge_count > 5) {
×
77
        throw InvalidSDFGException("GEMMNode must have 3-5 inputs");
×
78
    }
79
    if (graph.out_degree(*this) != 1) {
×
80
        throw InvalidSDFGException("GEMMNode must have 1 output");
×
81
    }
82

83
    // // Check if all inputs are connected A, B, C, (alpha), (beta)
84
    std::unordered_map<std::string, const data_flow::Memlet*> memlets;
×
85
    for (auto& input : this->inputs_) {
×
86
        bool found = false;
×
87
        for (auto& iedge : graph.in_edges(*this)) {
×
88
            if (iedge.dst_conn() == input) {
×
89
                found = true;
×
90
                memlets[input] = &iedge;
×
91
                break;
×
92
            }
93
        }
94
        if (!found && (input == "A" || input == "B" || input == "C")) {
×
95
            throw InvalidSDFGException("GEMMNode input " + input + " not found");
×
96
        }
97
    }
98

99
    // Check if output is connected to C
100
    auto& oedge = *graph.out_edges(*this).begin();
×
101
    if (oedge.src_conn() != this->outputs_.at(0)) {
×
102
        throw InvalidSDFGException("GEMMNode output " + this->outputs_.at(0) + " not found");
×
103
    }
104

105
    // Check dimensions of A, B, C
106
    auto& a_memlet = memlets.at("A");
×
107
    auto& a_subset_begin = a_memlet->begin_subset();
×
108
    auto& a_subset_end = a_memlet->end_subset();
×
109
    if (a_subset_begin.size() != 1) {
×
110
        throw InvalidSDFGException("GEMMNode input A must have 1 dimensions");
×
111
    }
112
    data_flow::Subset a_dims;
×
113
    for (size_t i = 0; i < a_subset_begin.size(); i++) {
×
114
        a_dims.push_back(symbolic::sub(a_subset_end[i], a_subset_begin[i]));
×
115
    }
×
116

117
    auto& b_memlet = memlets.at("B");
×
118
    auto& b_subset_begin = b_memlet->begin_subset();
×
119
    auto& b_subset_end = b_memlet->end_subset();
×
120
    if (b_subset_begin.size() != 1) {
×
121
        throw InvalidSDFGException("GEMMNode input B must have 1 dimensions");
×
122
    }
123
    data_flow::Subset b_dims;
×
124
    for (size_t i = 0; i < b_subset_begin.size(); i++) {
×
125
        b_dims.push_back(symbolic::sub(b_subset_end[i], b_subset_begin[i]));
×
126
    }
×
127

128
    auto& c_memlet = memlets.at("C");
×
129
    auto& c_subset_begin = c_memlet->begin_subset();
×
130
    auto& c_subset_end = c_memlet->end_subset();
×
131
    if (c_subset_begin.size() != 1) {
×
132
        throw InvalidSDFGException("GEMMNode input C must have 1 dimensions");
×
133
    }
134
    data_flow::Subset c_dims;
×
135
    for (size_t i = 0; i < c_subset_begin.size(); i++) {
×
136
        c_dims.push_back(symbolic::sub(c_subset_end[i], c_subset_begin[i]));
×
137
    }
×
138

139
    // TODO: Check if dimensions of A, B, C are valid
140
}
×
141

142
/// Maps the node's BLAS precision to the matching primitive type:
/// `s` -> Float, `d` -> Double, `h` -> Half. Any other precision yields Void,
/// which expand() treats as "cannot expand".
types::PrimitiveType GEMMNode::scalar_primitive() const {
    if (precision_ == BLAS_Precision::s) {
        return types::PrimitiveType::Float;
    }
    if (precision_ == BLAS_Precision::d) {
        return types::PrimitiveType::Double;
    }
    if (precision_ == BLAS_Precision::h) {
        return types::PrimitiveType::Half;
    }
    return types::PrimitiveType::Void;
}
1✔
154

155
bool GEMMNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
1✔
156
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
1✔
157

158
    auto& dataflow = this->get_parent();
1✔
159
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
1✔
160
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
1✔
161
    int index = parent.index(block);
1✔
162
    auto& transition = parent.at(index).second;
1✔
163

164
    if (trans_a_ != BLAS_Transpose::No || trans_b_ != BLAS_Transpose::No) {
1✔
165
        return false;
×
166
    }
167

168
    auto& alpha = this->alpha();
1✔
169
    auto& beta = this->beta();
1✔
170

171
    auto primitive_type = scalar_primitive();
1✔
172
    if (primitive_type == types::PrimitiveType::Void) {
1✔
173
        return false;
×
174
    }
175

176
    types::Scalar scalar_type(primitive_type);
1✔
177

178
    auto in_edges = dataflow.in_edges(*this);
1✔
179
    auto in_edges_it = in_edges.begin();
1✔
180

181
    data_flow::Memlet* iedge_a = nullptr;
1✔
182
    data_flow::Memlet* iedge_b = nullptr;
1✔
183
    data_flow::Memlet* iedge_c = nullptr;
1✔
184
    data_flow::Memlet* alpha_edge = nullptr;
1✔
185
    data_flow::Memlet* beta_edge = nullptr;
1✔
186
    while (in_edges_it != in_edges.end()) {
4✔
187
        auto& edge = *in_edges_it;
3✔
188
        auto dst_conn = edge.dst_conn();
3✔
189
        if (dst_conn == "A") {
3✔
190
            iedge_a = &edge;
1✔
191
        } else if (dst_conn == "B") {
3✔
192
            iedge_b = &edge;
1✔
193
        } else if (dst_conn == "C") {
2✔
194
            iedge_c = &edge;
1✔
195
        } else if (dst_conn == alpha) {
1✔
196
            alpha_edge = &edge;
×
197
        } else if (dst_conn == beta) {
×
198
            beta_edge = &edge;
×
199
        } else {
×
200
            throw InvalidSDFGException("GEMMNode has unexpected input: " + dst_conn);
×
201
        }
202
        ++in_edges_it;
3✔
203
    }
3✔
204

205
    auto& oedge = *dataflow.out_edges(*this).begin();
1✔
206

207
    // Checks if legal
208
    auto* input_node_a = dynamic_cast<data_flow::AccessNode*>(&iedge_a->src());
1✔
209
    auto* input_node_b = dynamic_cast<data_flow::AccessNode*>(&iedge_b->src());
1✔
210
    auto* input_node_c = dynamic_cast<data_flow::AccessNode*>(&iedge_c->src());
1✔
211
    auto* output_node = dynamic_cast<data_flow::AccessNode*>(&oedge.dst());
1✔
212
    data_flow::AccessNode* alpha_node = nullptr;
1✔
213
    data_flow::AccessNode* beta_node = nullptr;
1✔
214

215
    if (alpha_edge) {
1✔
216
        alpha_node = dynamic_cast<data_flow::AccessNode*>(&alpha_edge->src());
×
217
    }
×
218
    if (beta_edge) {
1✔
219
        beta_node = dynamic_cast<data_flow::AccessNode*>(&beta_edge->src());
×
220
    }
×
221

222
    // we must be the only thing in this block, as we do not support splitting a block into pre, expanded lib-node, post
223
    if (!input_node_a || dataflow.in_degree(*input_node_a) != 0 || !input_node_b ||
2✔
224
        dataflow.in_degree(*input_node_b) != 0 || !input_node_c || dataflow.in_degree(*input_node_c) != 0 ||
1✔
225
        !output_node || dataflow.out_degree(*output_node) != 0) {
1✔
226
        return false; // data nodes are not standalone
×
227
    }
228
    if ((alpha_node && dataflow.in_degree(*alpha_node) != 0) || (beta_node && dataflow.in_degree(*beta_node) != 0)) {
1✔
229
        return false; // alpha and beta are not standalone
×
230
    }
231
    for (auto* nd : dataflow.data_nodes()) {
5✔
232
        if (nd != input_node_a && nd != input_node_b && nd != input_node_c && nd != output_node &&
4✔
233
            (!alpha_node || nd != alpha_node) && (!beta_node || nd != beta_node)) {
×
234
            return false; // there are other nodes in here that we could not preserve correctly
×
235
        }
236
    }
237

238
    auto& A_var = input_node_a->data();
1✔
239
    auto& B_var = input_node_b->data();
1✔
240
    auto& C_in_var = input_node_c->data();
1✔
241
    auto& C_out_var = output_node->data();
1✔
242

243

244
    // Add new graph after the current block
245
    auto& new_sequence = builder.add_sequence_before(
2✔
246
        parent, block, transition.assignments(), builder.debug_info().get_region(block.debug_info().indices())
1✔
247
    );
248

249
    // Add maps
250
    std::vector<symbolic::Expression> indvar_ends{this->m(), this->n(), this->k()};
1✔
251
    auto& begin_subsets_out = oedge.begin_subset();
1✔
252
    auto& end_subsets_out = oedge.end_subset();
1✔
253
    auto& begin_subsets_in_a = iedge_a->begin_subset();
1✔
254
    auto& end_subsets_in_a = iedge_a->end_subset();
1✔
255
    data_flow::Subset new_subset;
1✔
256
    structured_control_flow::Sequence* last_scope = &new_sequence;
1✔
257
    structured_control_flow::Map* last_map = nullptr;
1✔
258
    structured_control_flow::Map* output_loop = nullptr;
1✔
259
    std::vector<std::string> indvar_names{"_i", "_j", "_k"};
1✔
260

261
    std::string sum_var = builder.find_new_name("_sum");
1✔
262
    builder.add_container(sum_var, scalar_type);
1✔
263

264
    for (size_t i = 0; i < 3; i++) {
4✔
265
        auto dim_begin = symbolic::zero();
3✔
266
        auto& dim_end = indvar_ends[i];
3✔
267

268
        std::string indvar_str = builder.find_new_name(indvar_names[i]);
3✔
269
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
3✔
270

271
        auto indvar = symbolic::symbol(indvar_str);
3✔
272
        auto init = dim_begin;
3✔
273
        auto update = symbolic::add(indvar, symbolic::one());
3✔
274
        auto condition = symbolic::Lt(indvar, dim_end);
3✔
275
        last_map = &builder.add_map(
6✔
276
            *last_scope,
3✔
277
            indvar,
278
            condition,
279
            init,
3✔
280
            update,
281
            structured_control_flow::ScheduleType_Sequential::create(),
3✔
282
            {},
3✔
283
            builder.subject().debug_info().get_region(block.debug_info().indices())
3✔
284
        );
285
        last_scope = &last_map->root();
3✔
286

287
        if (i == 1) {
3✔
288
            output_loop = last_map;
1✔
289
        }
1✔
290

291
        new_subset.push_back(indvar);
3✔
292
    }
3✔
293

294

295
    // Add code
296
    auto& init_block = builder.add_block_before(
2✔
297
        output_loop->root(), *last_map, {}, builder.debug_info().get_region(block.debug_info().indices())
1✔
298
    );
299
    auto& sum_init =
1✔
300
        builder.add_access(init_block, sum_var, builder.debug_info().get_region(block.debug_info().indices()));
1✔
301

302
    auto& init_tasklet = builder.add_tasklet(
2✔
303
        init_block,
1✔
304
        data_flow::assign,
305
        "_out",
1✔
306
        {"0.0"},
1✔
307
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
308
    );
309

310
    builder.add_computational_memlet(
2✔
311
        init_block,
1✔
312
        init_tasklet,
1✔
313
        "_out",
1✔
314
        sum_init,
1✔
315
        {},
1✔
316
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
317
    );
318

319
    auto& code_block =
1✔
320
        builder.add_block(*last_scope, {}, builder.subject().debug_info().get_region(block.debug_info().indices()));
1✔
321
    auto& input_node_a_new = builder.add_access(
2✔
322
        code_block, A_var, builder.subject().debug_info().get_region(input_node_a->debug_info().indices())
1✔
323
    );
324
    auto& input_node_b_new = builder.add_access(
2✔
325
        code_block, B_var, builder.subject().debug_info().get_region(input_node_b->debug_info().indices())
1✔
326
    );
327

328
    auto& core_fma = builder.add_tasklet(
2✔
329
        code_block,
1✔
330
        data_flow::fma,
331
        "_out",
1✔
332
        {"_in1", "_in2", "_in3"},
1✔
333
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
334
    );
335
    auto& sum_in =
1✔
336
        builder.add_access(code_block, sum_var, builder.subject().debug_info().get_region(block.debug_info().indices()));
1✔
337
    auto& sum_out =
1✔
338
        builder.add_access(code_block, sum_var, builder.subject().debug_info().get_region(block.debug_info().indices()));
1✔
339
    builder.add_computational_memlet(
2✔
340
        code_block, sum_in, core_fma, "_in3", {}, builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
341
    );
342

343
    symbolic::Expression a_idx = symbolic::add(symbolic::mul(lda(), new_subset[0]), new_subset[2]);
1✔
344
    builder.add_computational_memlet(
2✔
345
        code_block,
1✔
346
        input_node_a_new,
1✔
347
        core_fma,
1✔
348
        "_in1",
1✔
349
        {a_idx},
1✔
350
        iedge_a->base_type(),
1✔
351
        builder.subject().debug_info().get_region(iedge_a->debug_info().indices())
1✔
352
    );
353
    symbolic::Expression b_idx = symbolic::add(symbolic::mul(ldb(), new_subset[2]), new_subset[1]);
1✔
354
    builder.add_computational_memlet(
2✔
355
        code_block,
1✔
356
        input_node_b_new,
1✔
357
        core_fma,
1✔
358
        "_in2",
1✔
359
        {b_idx},
1✔
360
        iedge_b->base_type(),
1✔
361
        builder.subject().debug_info().get_region(iedge_b->debug_info().indices())
1✔
362
    );
363
    builder.add_computational_memlet(
2✔
364
        code_block,
1✔
365
        core_fma,
1✔
366
        "_out",
1✔
367
        sum_out,
1✔
368
        {},
1✔
369
        builder.subject().debug_info().get_region(oedge.debug_info().indices())
1✔
370
    );
371

372
    auto& flush_block = builder.add_block_after(
2✔
373
        output_loop->root(), *last_map, {}, builder.debug_info().get_region(block.debug_info().indices())
1✔
374
    );
375
    auto& sum_final =
1✔
376
        builder.add_access(flush_block, sum_var, builder.debug_info().get_region(block.debug_info().indices()));
1✔
377
    auto& input_node_c_new =
1✔
378
        builder.add_access(flush_block, C_in_var, builder.debug_info().get_region(input_node_c->debug_info().indices()));
1✔
379
    symbolic::Expression c_idx = symbolic::add(symbolic::mul(ldc(), new_subset[0]), new_subset[1]);
1✔
380

381
    auto& scale_sum_tasklet = builder.add_tasklet(
2✔
382
        flush_block,
1✔
383
        data_flow::mul,
384
        "_out",
1✔
385
        {"_in1", alpha},
1✔
386
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
387
    );
388
    builder.add_computational_memlet(
2✔
389
        flush_block,
1✔
390
        sum_final,
1✔
391
        scale_sum_tasklet,
1✔
392
        "_in1",
1✔
393
        {},
1✔
394
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
395
    );
396
    if (alpha_node) {
1✔
NEW
397
        auto& alpha_node_new = builder.add_access(
×
NEW
398
            flush_block, alpha_node->data(), builder.subject().debug_info().get_region(block.debug_info().indices())
×
399
        );
NEW
400
        builder.add_computational_memlet(
×
NEW
401
            flush_block,
×
NEW
402
            scale_sum_tasklet,
×
NEW
403
            alpha,
×
NEW
404
            alpha_node_new,
×
NEW
405
            {},
×
NEW
406
            builder.subject().debug_info().get_region(block.debug_info().indices())
×
407
        );
UNCOV
408
    }
×
409

410
    std::string scaled_sum_temp = builder.find_new_name("scaled_sum_temp");
1✔
411
    builder.add_container(scaled_sum_temp, scalar_type);
1✔
412
    auto& scaled_sum_final = builder.add_access(
2✔
413
        flush_block, scaled_sum_temp, builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
414
    );
415
    builder.add_computational_memlet(
2✔
416
        flush_block,
1✔
417
        scale_sum_tasklet,
1✔
418
        "_out",
1✔
419
        scaled_sum_final,
1✔
420
        {},
1✔
421
        scalar_type,
422
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
423
    );
424

425
    auto& scale_input_tasklet = builder.add_tasklet(
2✔
426
        flush_block,
1✔
427
        data_flow::mul,
428
        "_out",
1✔
429
        {"_in1", beta},
1✔
430
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
431
    );
432
    builder.add_computational_memlet(
2✔
433
        flush_block,
1✔
434
        input_node_c_new,
1✔
435
        scale_input_tasklet,
1✔
436
        "_in1",
1✔
437
        {c_idx},
1✔
438
        iedge_c->base_type(),
1✔
439
        builder.subject().debug_info().get_region(iedge_c->debug_info().indices())
1✔
440
    );
441
    if (beta_node) {
1✔
NEW
442
        auto& beta_node_new = builder.add_access(
×
NEW
443
            flush_block, beta_node->data(), builder.subject().debug_info().get_region(block.debug_info().indices())
×
444
        );
445
        builder.add_computational_memlet(
×
NEW
446
            flush_block,
×
NEW
447
            scale_sum_tasklet,
×
NEW
448
            beta,
×
NEW
449
            beta_node_new,
×
NEW
450
            {},
×
451
            scalar_type,
NEW
452
            builder.subject().debug_info().get_region(block.debug_info().indices())
×
453
        );
454
    }
×
455

456
    std::string scaled_input_temp = builder.find_new_name("scaled_input_temp");
1✔
457
    builder.add_container(scaled_input_temp, scalar_type);
1✔
458
    auto& scaled_input_c = builder.add_access(
2✔
459
        flush_block, scaled_input_temp, builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
460
    );
461
    builder.add_computational_memlet(
2✔
462
        flush_block,
1✔
463
        scale_input_tasklet,
1✔
464
        "_out",
1✔
465
        scaled_input_c,
1✔
466
        {},
1✔
467
        scalar_type,
468
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
469
    );
470

471
    auto& flush_add_tasklet = builder.add_tasklet(
2✔
472
        flush_block,
1✔
473
        data_flow::add,
474
        "_out",
1✔
475
        {"_in1", "_in2"},
1✔
476
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
477
    );
478
    auto& output_node_new = builder.add_access(
2✔
479
        flush_block, C_out_var, builder.subject().debug_info().get_region(output_node->debug_info().indices())
1✔
480
    );
481
    builder.add_computational_memlet(
2✔
482
        flush_block,
1✔
483
        scaled_sum_final,
1✔
484
        flush_add_tasklet,
1✔
485
        "_in1",
1✔
486
        {},
1✔
487
        scalar_type,
488
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
489
    );
490
    builder.add_computational_memlet(
2✔
491
        flush_block,
1✔
492
        scaled_input_c,
1✔
493
        flush_add_tasklet,
1✔
494
        "_in2",
1✔
495
        {},
1✔
496
        scalar_type,
497
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
498
    );
499
    builder.add_computational_memlet(
2✔
500
        flush_block,
1✔
501
        flush_add_tasklet,
1✔
502
        "_out",
1✔
503
        output_node_new,
1✔
504
        {c_idx},
1✔
505
        iedge_c->base_type(),
1✔
506
        builder.subject().debug_info().get_region(iedge_c->debug_info().indices())
1✔
507
    );
508

509

510
    // Clean up block
511
    builder.remove_memlet(block, *iedge_a);
1✔
512
    builder.remove_memlet(block, *iedge_b);
1✔
513
    builder.remove_memlet(block, *iedge_c);
1✔
514
    if (alpha_edge) {
1✔
515
        builder.remove_memlet(block, *alpha_edge);
×
516
        builder.remove_node(block, *alpha_node);
×
517
    }
×
518
    if (beta_edge) {
1✔
519
        builder.remove_memlet(block, *beta_edge);
×
520
        builder.remove_node(block, *beta_node);
×
521
    }
×
522
    builder.remove_memlet(block, oedge);
1✔
523
    builder.remove_node(block, *input_node_a);
1✔
524
    builder.remove_node(block, *input_node_b);
1✔
525
    builder.remove_node(block, *input_node_c);
1✔
526
    builder.remove_node(block, *output_node);
1✔
527
    builder.remove_node(block, *this);
1✔
528
    builder.remove_child(parent, index + 1);
1✔
529

530
    return true;
1✔
531
}
1✔
532

533
std::unique_ptr<data_flow::DataFlowNode> GEMMNode::
534
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
535
    auto node_clone = std::unique_ptr<GEMMNode>(new GEMMNode(
×
536
        element_id,
×
537
        this->debug_info(),
×
538
        vertex,
×
539
        parent,
×
540
        this->implementation_type_,
×
541
        this->precision_,
×
542
        this->layout_,
×
543
        this->trans_a_,
×
544
        this->trans_b_,
×
545
        this->m_,
×
546
        this->n_,
×
547
        this->k_,
×
548
        this->lda_,
×
549
        this->ldb_,
×
550
        this->ldc_,
×
551
        this->alpha(),
×
552
        this->beta()
×
553
    ));
554
    return std::move(node_clone);
×
555
}
×
556

557
std::string GEMMNode::toStr() const {
1✔
558
    return LibraryNode::toStr() + "(" + static_cast<char>(precision_) + ", " +
3✔
559
           std::string(BLAS_Layout_to_short_string(layout_)) + ", " + BLAS_Transpose_to_char(trans_a_) +
3✔
560
           BLAS_Transpose_to_char(trans_b_) + ", " + m_->__str__() + ", " + n_->__str__() + ", " + k_->__str__() +
2✔
561
           ", " + lda_->__str__() + ", " + ldb_->__str__() + ", " + ldc_->__str__() + ")";
1✔
562
}
×
563

564
/// Serializes the GEMM-specific state of `library_node` (which is cast to GEMMNode without a
/// runtime check) into a JSON object with the keys
/// code, precision, layout, trans_a, trans_b, m, n, k, lda, ldb, ldc, alpha, beta.
/// Symbolic expressions are converted through JSONSerializer::expression.
nlohmann::json GEMMNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& gemm = static_cast<const GEMMNode&>(library_node);
    serializer::JSONSerializer expr_writer;

    nlohmann::json result;

    // Identification and BLAS configuration.
    result["code"] = gemm.code().value();
    result["precision"] = gemm.precision();
    result["layout"] = gemm.layout();
    result["trans_a"] = gemm.trans_a();
    result["trans_b"] = gemm.trans_b();

    // Symbolic sizes and leading dimensions.
    result["m"] = expr_writer.expression(gemm.m());
    result["n"] = expr_writer.expression(gemm.n());
    result["k"] = expr_writer.expression(gemm.k());
    result["lda"] = expr_writer.expression(gemm.lda());
    result["ldb"] = expr_writer.expression(gemm.ldb());
    result["ldc"] = expr_writer.expression(gemm.ldc());

    // Scalar factors.
    result["alpha"] = gemm.alpha();
    result["beta"] = gemm.beta();

    return result;
}
×
585

586
/// Re-creates a GEMMNode inside `parent` from its JSON description.
///
/// Reads the keys code, debug_info, precision, layout, trans_a, trans_b, m, n, k, lda, ldb,
/// ldc, alpha, beta and implementation_type. "element_id" is only asserted to be present;
/// its value is not read here.
///
/// Throws std::runtime_error if "code" does not name the GEMM library node type, and the
/// JSON library's exception (via `at()`) if any key that is read is missing. Every lookup
/// goes through `at()` so that a malformed document is rejected even in builds where
/// `assert` is compiled out.
///
/// Returns the node created by the builder.
data_flow::LibraryNode& GEMMNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_GEMM.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfoRegion debug_info = serializer.json_to_debug_info_region(j.at("debug_info"), builder.debug_info());

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto layout = j.at("layout").get<BLAS_Layout>();
    auto trans_a = j.at("trans_a").get<BLAS_Transpose>();
    auto trans_b = j.at("trans_b").get<BLAS_Transpose>();
    auto m = SymEngine::Expression(j.at("m"));
    auto n = SymEngine::Expression(j.at("n"));
    auto k = SymEngine::Expression(j.at("k"));
    auto lda = SymEngine::Expression(j.at("lda"));
    auto ldb = SymEngine::Expression(j.at("ldb"));
    auto ldc = SymEngine::Expression(j.at("ldc"));
    auto alpha = j.at("alpha").get<std::string>();
    auto beta = j.at("beta").get<std::string>();

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<GEMMNode>(
        parent, debug_info, implementation_type, precision, layout, trans_a, trans_b, m, n, k, lda, ldb, ldc, alpha, beta
    );
}
×
622

623
/// Dispatcher that emits a CBLAS call for a GEMM node.
/// All arguments are forwarded unchanged to the LibraryNodeDispatcher base.
GEMMNodeDispatcher_BLAS::GEMMNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {
}
×
630

631
void GEMMNodeDispatcher_BLAS::dispatch(
×
632
    codegen::PrettyPrinter& stream,
633
    codegen::PrettyPrinter& globals_stream,
634
    codegen::CodeSnippetFactory& library_snippet_factory
635
) {
636
    stream << "{" << std::endl;
×
637
    stream.setIndent(stream.indent() + 4);
×
638

639
    auto& gemm_node = static_cast<const GEMMNode&>(this->node_);
×
640

641
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
642
    switch (gemm_node.precision()) {
×
643
        case BLAS_Precision::h:
644
            base_type = types::Scalar(types::PrimitiveType::Half);
×
645
            break;
×
646
        case BLAS_Precision::s:
647
            base_type = types::Scalar(types::PrimitiveType::Float);
×
648
            break;
×
649
        case BLAS_Precision::d:
650
            base_type = types::Scalar(types::PrimitiveType::Double);
×
651
            break;
×
652
        default:
653
            throw std::runtime_error("Invalid BLAS_Precision value");
×
654
    }
655

656
    auto& graph = this->node_.get_parent();
×
657
    for (auto& iedge : graph.in_edges(this->node_)) {
×
658
        auto& access_node = static_cast<const data_flow::AccessNode&>(iedge.src());
×
659
        std::string name = access_node.data();
×
660
        auto& type = this->function_.type(name);
×
661

662
        stream << this->language_extension_.declaration(iedge.dst_conn(), type);
×
663
        stream << " = " << name << ";" << std::endl;
×
664
    }
×
665

666
    if (std::find(gemm_node.inputs().begin(), gemm_node.inputs().end(), "alpha") ==
×
667
        gemm_node.inputs().end()) { // TODO obsolute, must be an input!
×
668
        stream << this->language_extension_.declaration("alpha", base_type);
×
669
        stream << " = " << gemm_node.alpha() << ";" << std::endl;
×
670
    }
×
671
    if (std::find(gemm_node.inputs().begin(), gemm_node.inputs().end(), "beta") == gemm_node.inputs().end()) {
×
672
        stream << this->language_extension_.declaration("beta", base_type);
×
673
        stream << " = " << gemm_node.beta() << ";" << std::endl;
×
674
    }
×
675

676
    stream << "cblas_" << BLAS_Precision_to_string(gemm_node.precision()) << "gemm(";
×
677
    stream.setIndent(stream.indent() + 4);
×
678
    stream << BLAS_Layout_to_string(gemm_node.layout());
×
679
    stream << ", ";
×
680
    stream << BLAS_Transpose_to_string(gemm_node.trans_a());
×
681
    stream << ", ";
×
682
    stream << BLAS_Transpose_to_string(gemm_node.trans_b());
×
683
    stream << ", ";
×
684
    stream << this->language_extension_.expression(gemm_node.m());
×
685
    stream << ", ";
×
686
    stream << this->language_extension_.expression(gemm_node.n());
×
687
    stream << ", ";
×
688
    stream << this->language_extension_.expression(gemm_node.k());
×
689
    stream << ", ";
×
690
    stream << "alpha";
×
691
    stream << ", ";
×
692
    stream << "A";
×
693
    stream << ", ";
×
694
    stream << this->language_extension_.expression(gemm_node.lda());
×
695
    stream << ", ";
×
696
    stream << "B";
×
697
    stream << ", ";
×
698
    stream << this->language_extension_.expression(gemm_node.ldb());
×
699
    stream << ", ";
×
700
    stream << "beta";
×
701
    stream << ", ";
×
702
    stream << "C";
×
703
    stream << ", ";
×
704
    stream << this->language_extension_.expression(gemm_node.ldc());
×
705

706
    stream.setIndent(stream.indent() - 4);
×
707
    stream << ");" << std::endl;
×
708

709
    stream.setIndent(stream.indent() - 4);
×
710
    stream << "}" << std::endl;
×
711
}
×
712

713
/// Dispatcher placeholder for a cuBLAS-backed GEMM.
/// All arguments are forwarded unchanged to the LibraryNodeDispatcher base.
GEMMNodeDispatcher_CUBLAS::GEMMNodeDispatcher_CUBLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {
}
×
720

721
void GEMMNodeDispatcher_CUBLAS::dispatch(
×
722
    codegen::PrettyPrinter& stream,
723
    codegen::PrettyPrinter& globals_stream,
724
    codegen::CodeSnippetFactory& library_snippet_factory
725
) {
726
    throw std::runtime_error("GEMMNodeDispatcher_CUBLAS not implemented");
×
727
}
×
728

729
} // namespace blas
730
} // namespace math
731
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc