• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 19917337940

01 Dec 2025 04:38PM UTC coverage: 61.822% (-0.06%) from 61.885%
19917337940

push

github

web-flow
Merge pull request #369 from daisytuner/cleaner-flop-analysis-api

Cleaner API to FlopAnalysis to hide its internal artifacts when we on…

11 of 56 new or added lines in 7 files covered. (19.64%)

3 existing lines in 2 files now uncovered.

11254 of 18204 relevant lines covered (61.82%)

110.69 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.35
/src/data_flow/library_nodes/math/blas/gemm.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/gemm.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace blas {
11

12
GEMMNode::GEMMNode(
1✔
13
    size_t element_id,
14
    const DebugInfo& debug_info,
15
    const graph::Vertex vertex,
16
    data_flow::DataFlowGraph& parent,
17
    const data_flow::ImplementationType& implementation_type,
18
    const BLAS_Precision& precision,
19
    const BLAS_Layout& layout,
20
    const BLAS_Transpose& trans_a,
21
    const BLAS_Transpose& trans_b,
22
    symbolic::Expression m,
23
    symbolic::Expression n,
24
    symbolic::Expression k,
25
    symbolic::Expression lda,
26
    symbolic::Expression ldb,
27
    symbolic::Expression ldc
28
)
29
    : BLASNode(
1✔
30
          element_id,
1✔
31
          debug_info,
1✔
32
          vertex,
1✔
33
          parent,
1✔
34
          LibraryNodeType_GEMM,
35
          {"C"},
1✔
36
          {"A", "B", "C", "alpha", "beta"},
1✔
37
          implementation_type,
1✔
38
          precision
1✔
39
      ),
40
      layout_(layout), trans_a_(trans_a), trans_b_(trans_b), m_(m), n_(n), k_(k), lda_(lda), ldb_(ldb), ldc_(ldc) {}
1✔
41

42
/// @return The BLAS storage layout this node was constructed with.
BLAS_Layout GEMMNode::layout() const { return this->layout_; }
43

44
/// @return The transpose mode applied to operand A.
BLAS_Transpose GEMMNode::trans_a() const { return this->trans_a_; }
45

46
/// @return The transpose mode applied to operand B.
BLAS_Transpose GEMMNode::trans_b() const { return this->trans_b_; }
47

48
/// @return The symbolic row count m of the result C.
symbolic::Expression GEMMNode::m() const { return this->m_; }
49

50
/// @return The symbolic column count n of the result C.
symbolic::Expression GEMMNode::n() const { return this->n_; }
51

52
/// @return The symbolic reduction length k.
symbolic::Expression GEMMNode::k() const { return this->k_; }
53

54
/// @return The symbolic leading dimension of A.
symbolic::Expression GEMMNode::lda() const { return this->lda_; }
55

56
/// @return The symbolic leading dimension of B.
symbolic::Expression GEMMNode::ldb() const { return this->ldb_; }
57

58
/// @return The symbolic leading dimension of C.
symbolic::Expression GEMMNode::ldc() const { return this->ldc_; }
59

60
/// Validation hook for GEMM nodes.
///
/// Currently performs no checks and never throws.
void GEMMNode::validate(const Function& /* function */) const {
    // Intentionally empty.
}
61

62
bool GEMMNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
1✔
63
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
1✔
64

65
    auto& dataflow = this->get_parent();
1✔
66
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
1✔
67
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
1✔
68
    int index = parent.index(block);
1✔
69
    auto& transition = parent.at(index).second;
1✔
70

71
    if (trans_a_ != BLAS_Transpose::No || trans_b_ != BLAS_Transpose::No) {
1✔
72
        return false;
×
73
    }
74

75
    auto primitive_type = scalar_primitive();
1✔
76
    if (primitive_type == types::PrimitiveType::Void) {
1✔
77
        return false;
×
78
    }
79

80
    types::Scalar scalar_type(primitive_type);
1✔
81

82
    auto in_edges = dataflow.in_edges(*this);
1✔
83
    auto in_edges_it = in_edges.begin();
1✔
84

85
    data_flow::Memlet* iedge_a = nullptr;
1✔
86
    data_flow::Memlet* iedge_b = nullptr;
1✔
87
    data_flow::Memlet* iedge_c = nullptr;
1✔
88
    data_flow::Memlet* alpha_edge = nullptr;
1✔
89
    data_flow::Memlet* beta_edge = nullptr;
1✔
90
    while (in_edges_it != in_edges.end()) {
6✔
91
        auto& edge = *in_edges_it;
5✔
92
        auto dst_conn = edge.dst_conn();
5✔
93
        if (dst_conn == "A") {
5✔
94
            iedge_a = &edge;
1✔
95
        } else if (dst_conn == "B") {
5✔
96
            iedge_b = &edge;
1✔
97
        } else if (dst_conn == "C") {
4✔
98
            iedge_c = &edge;
1✔
99
        } else if (dst_conn == "alpha") {
3✔
100
            alpha_edge = &edge;
1✔
101
        } else if (dst_conn == "beta") {
2✔
102
            beta_edge = &edge;
1✔
103
        } else {
1✔
104
            throw InvalidSDFGException("GEMMNode has unexpected input: " + dst_conn);
×
105
        }
106
        ++in_edges_it;
5✔
107
    }
5✔
108

109
    auto& oedge = *dataflow.out_edges(*this).begin();
1✔
110

111
    // Checks if legal
112
    auto* input_node_a = static_cast<data_flow::AccessNode*>(&iedge_a->src());
1✔
113
    auto* input_node_b = static_cast<data_flow::AccessNode*>(&iedge_b->src());
1✔
114
    auto* input_node_c = static_cast<data_flow::AccessNode*>(&iedge_c->src());
1✔
115
    auto* output_node = static_cast<data_flow::AccessNode*>(&oedge.dst());
1✔
116
    auto* alpha_node = static_cast<data_flow::AccessNode*>(&alpha_edge->src());
1✔
117
    auto* beta_node = static_cast<data_flow::AccessNode*>(&beta_edge->src());
1✔
118

119
    // we must be the only thing in this block, as we do not support splitting a block into pre, expanded lib-node, post
120
    if (!input_node_a || dataflow.in_degree(*input_node_a) != 0 || !input_node_b ||
2✔
121
        dataflow.in_degree(*input_node_b) != 0 || !input_node_c || dataflow.in_degree(*input_node_c) != 0 ||
1✔
122
        !output_node || dataflow.out_degree(*output_node) != 0) {
1✔
123
        return false; // data nodes are not standalone
×
124
    }
125
    if (dataflow.in_degree(*alpha_node) != 0 || dataflow.in_degree(*beta_node) != 0) {
1✔
126
        return false; // alpha and beta are not standalone
×
127
    }
128
    for (auto* nd : dataflow.data_nodes()) {
7✔
129
        if (nd != input_node_a && nd != input_node_b && nd != input_node_c && nd != output_node &&
7✔
130
            (!alpha_node || nd != alpha_node) && (!beta_node || nd != beta_node)) {
2✔
131
            return false; // there are other nodes in here that we could not preserve correctly
×
132
        }
133
    }
134

135
    auto& A_var = input_node_a->data();
1✔
136
    auto& B_var = input_node_b->data();
1✔
137
    auto& C_in_var = input_node_c->data();
1✔
138
    auto& C_out_var = output_node->data();
1✔
139

140

141
    // Add new graph after the current block
142
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
1✔
143

144
    // Add maps
145
    std::vector<symbolic::Expression> indvar_ends{this->m(), this->n(), this->k()};
1✔
146
    data_flow::Subset new_subset;
1✔
147
    structured_control_flow::Sequence* last_scope = &new_sequence;
1✔
148
    structured_control_flow::Map* last_map = nullptr;
1✔
149
    structured_control_flow::Map* output_loop = nullptr;
1✔
150
    std::vector<std::string> indvar_names{"_i", "_j", "_k"};
1✔
151

152
    std::string sum_var = builder.find_new_name("_sum");
1✔
153
    builder.add_container(sum_var, scalar_type);
1✔
154

155
    for (size_t i = 0; i < 3; i++) {
4✔
156
        auto dim_begin = symbolic::zero();
3✔
157
        auto& dim_end = indvar_ends[i];
3✔
158

159
        std::string indvar_str = builder.find_new_name(indvar_names[i]);
3✔
160
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
3✔
161

162
        auto indvar = symbolic::symbol(indvar_str);
3✔
163
        auto init = dim_begin;
3✔
164
        auto update = symbolic::add(indvar, symbolic::one());
3✔
165
        auto condition = symbolic::Lt(indvar, dim_end);
3✔
166
        last_map = &builder.add_map(
6✔
167
            *last_scope,
3✔
168
            indvar,
3✔
169
            condition,
3✔
170
            init,
3✔
171
            update,
3✔
172
            structured_control_flow::ScheduleType_Sequential::create(),
3✔
173
            {},
3✔
174
            block.debug_info()
3✔
175
        );
176
        last_scope = &last_map->root();
3✔
177

178
        if (i == 1) {
3✔
179
            output_loop = last_map;
1✔
180
        }
1✔
181

182
        new_subset.push_back(indvar);
3✔
183
    }
3✔
184

185

186
    // Add code
187
    auto& init_block = builder.add_block_before(output_loop->root(), *last_map, {}, block.debug_info());
1✔
188
    auto& sum_init = builder.add_access(init_block, sum_var, block.debug_info());
1✔
189

190
    auto& zero_node = builder.add_constant(init_block, "0.0", alpha_edge->base_type(), block.debug_info());
1✔
191
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
1✔
192
    builder.add_computational_memlet(init_block, zero_node, init_tasklet, "_in", {}, block.debug_info());
1✔
193
    builder.add_computational_memlet(init_block, init_tasklet, "_out", sum_init, {}, block.debug_info());
1✔
194

195
    auto& code_block = builder.add_block(*last_scope, {}, block.debug_info());
1✔
196
    auto& input_node_a_new = builder.add_access(code_block, A_var, input_node_a->debug_info());
1✔
197
    auto& input_node_b_new = builder.add_access(code_block, B_var, input_node_b->debug_info());
1✔
198

199
    auto& core_fma =
1✔
200
        builder.add_tasklet(code_block, data_flow::fp_fma, "_out", {"_in1", "_in2", "_in3"}, block.debug_info());
1✔
201
    auto& sum_in = builder.add_access(code_block, sum_var, block.debug_info());
1✔
202
    auto& sum_out = builder.add_access(code_block, sum_var, block.debug_info());
1✔
203
    builder.add_computational_memlet(code_block, sum_in, core_fma, "_in3", {}, block.debug_info());
1✔
204

205
    symbolic::Expression a_idx = symbolic::add(symbolic::mul(lda(), new_subset[0]), new_subset[2]);
1✔
206
    builder.add_computational_memlet(
2✔
207
        code_block, input_node_a_new, core_fma, "_in1", {a_idx}, iedge_a->base_type(), iedge_a->debug_info()
1✔
208
    );
209
    symbolic::Expression b_idx = symbolic::add(symbolic::mul(ldb(), new_subset[2]), new_subset[1]);
1✔
210
    builder.add_computational_memlet(
2✔
211
        code_block, input_node_b_new, core_fma, "_in2", {b_idx}, iedge_b->base_type(), iedge_b->debug_info()
1✔
212
    );
213
    builder.add_computational_memlet(code_block, core_fma, "_out", sum_out, {}, oedge.debug_info());
1✔
214

215
    auto& flush_block = builder.add_block_after(output_loop->root(), *last_map, {}, block.debug_info());
1✔
216
    auto& sum_final = builder.add_access(flush_block, sum_var, block.debug_info());
1✔
217
    auto& input_node_c_new = builder.add_access(flush_block, C_in_var, input_node_c->debug_info());
1✔
218
    symbolic::Expression c_idx = symbolic::add(symbolic::mul(ldc(), new_subset[0]), new_subset[1]);
1✔
219

220
    auto& scale_sum_tasklet =
1✔
221
        builder.add_tasklet(flush_block, data_flow::TaskletCode::fp_mul, "_out", {"_in1", "_in2"}, block.debug_info());
1✔
222
    builder.add_computational_memlet(flush_block, sum_final, scale_sum_tasklet, "_in1", {}, block.debug_info());
1✔
223
    if (auto const_node = dynamic_cast<data_flow::ConstantNode*>(alpha_node)) {
1✔
224
        auto& alpha_node_new =
1✔
225
            builder.add_constant(flush_block, const_node->data(), const_node->type(), block.debug_info());
1✔
226
        builder.add_computational_memlet(flush_block, alpha_node_new, scale_sum_tasklet, "_in2", {}, block.debug_info());
1✔
227
    } else {
1✔
228
        auto& alpha_node_new = builder.add_access(flush_block, alpha_node->data(), block.debug_info());
×
229
        builder.add_computational_memlet(flush_block, alpha_node_new, scale_sum_tasklet, "_in2", {}, block.debug_info());
×
230
    }
231

232
    std::string scaled_sum_temp = builder.find_new_name("scaled_sum_temp");
1✔
233
    builder.add_container(scaled_sum_temp, scalar_type);
1✔
234
    auto& scaled_sum_final = builder.add_access(flush_block, scaled_sum_temp, block.debug_info());
1✔
235
    builder.add_computational_memlet(
2✔
236
        flush_block, scale_sum_tasklet, "_out", scaled_sum_final, {}, scalar_type, block.debug_info()
1✔
237
    );
238

239
    auto& scale_input_tasklet =
1✔
240
        builder.add_tasklet(flush_block, data_flow::TaskletCode::fp_mul, "_out", {"_in1", "_in2"}, block.debug_info());
1✔
241
    builder.add_computational_memlet(
2✔
242
        flush_block, input_node_c_new, scale_input_tasklet, "_in1", {c_idx}, iedge_c->base_type(), iedge_c->debug_info()
1✔
243
    );
244
    if (auto const_node = dynamic_cast<data_flow::ConstantNode*>(beta_node)) {
1✔
245
        auto& beta_node_new =
1✔
246
            builder.add_constant(flush_block, const_node->data(), const_node->type(), block.debug_info());
1✔
247
        builder
2✔
248
            .add_computational_memlet(flush_block, beta_node_new, scale_input_tasklet, "_in2", {}, block.debug_info());
1✔
249
    } else {
1✔
250
        auto& beta_node_new = builder.add_access(flush_block, beta_node->data(), block.debug_info());
×
251
        builder
×
252
            .add_computational_memlet(flush_block, beta_node_new, scale_input_tasklet, "_in2", {}, block.debug_info());
×
253
    }
254

255
    std::string scaled_input_temp = builder.find_new_name("scaled_input_temp");
1✔
256
    builder.add_container(scaled_input_temp, scalar_type);
1✔
257
    auto& scaled_input_c = builder.add_access(flush_block, scaled_input_temp, block.debug_info());
1✔
258
    builder.add_computational_memlet(
2✔
259
        flush_block, scale_input_tasklet, "_out", scaled_input_c, {}, scalar_type, block.debug_info()
1✔
260
    );
261

262
    auto& flush_add_tasklet =
1✔
263
        builder.add_tasklet(flush_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, block.debug_info());
1✔
264
    auto& output_node_new = builder.add_access(flush_block, C_out_var, output_node->debug_info());
1✔
265
    builder.add_computational_memlet(
2✔
266
        flush_block, scaled_sum_final, flush_add_tasklet, "_in1", {}, scalar_type, block.debug_info()
1✔
267
    );
268
    builder.add_computational_memlet(
2✔
269
        flush_block, scaled_input_c, flush_add_tasklet, "_in2", {}, scalar_type, block.debug_info()
1✔
270
    );
271
    builder.add_computational_memlet(
2✔
272
        flush_block, flush_add_tasklet, "_out", output_node_new, {c_idx}, iedge_c->base_type(), iedge_c->debug_info()
1✔
273
    );
274

275

276
    // Clean up block
277
    builder.remove_memlet(block, *iedge_a);
1✔
278
    builder.remove_memlet(block, *iedge_b);
1✔
279
    builder.remove_memlet(block, *iedge_c);
1✔
280
    builder.remove_memlet(block, *alpha_edge);
1✔
281
    builder.remove_node(block, *alpha_node);
1✔
282
    builder.remove_memlet(block, *beta_edge);
1✔
283
    builder.remove_node(block, *beta_node);
1✔
284
    builder.remove_memlet(block, oedge);
1✔
285
    builder.remove_node(block, *input_node_a);
1✔
286
    builder.remove_node(block, *input_node_b);
1✔
287
    builder.remove_node(block, *input_node_c);
1✔
288
    builder.remove_node(block, *output_node);
1✔
289
    builder.remove_node(block, *this);
1✔
290
    builder.remove_child(parent, index + 1);
1✔
291

292
    return true;
1✔
293
}
1✔
294

NEW
295
/// Flop count assuming generic scalars: every alpha/beta condition accepted
/// by flops() is taken to be true, so no work is assumed to be skipped.
symbolic::Expression GEMMNode::flop() const {
    const auto always = symbolic::__true__();
    return flops(always, always, always, always);
}
298

NEW
299
/// Symbolic floating-point operation count of C = alpha * A * B + beta * C.
///
/// Each term is multiplied by the condition under which that work is needed,
/// so callers can pass facts about alpha and beta to exclude skipped work:
///  - A * B: m*n*k multiplications and m*n*(k-1) additions      [alpha != 0]
///  - alpha * (A * B): m*n multiplications        [alpha != 0 && alpha != 1]
///  - beta * C: m*n multiplications                  [beta != 0 && beta != 1]
///  - the final sum of both terms: m*n additions   [alpha != 0 && beta != 0]
///
/// NOTE(review): the (k - 1) addition term assumes k >= 1.
///
/// @param alpha_non_zero  Condition "alpha != 0".
/// @param alpha_non_ident Condition "alpha != 1".
/// @param beta_non_zero   Condition "beta != 0".
/// @param beta_non_ident  Condition "beta != 1".
/// @return Total multiplications plus additions as a symbolic expression.
symbolic::Expression GEMMNode::flops(
    symbolic::Condition alpha_non_zero,
    symbolic::Condition alpha_non_ident,
    symbolic::Condition beta_non_zero,
    symbolic::Condition beta_non_ident
) const {
    auto res_elems = symbolic::mul(this->m_, this->n_);

    // conditional on alpha != 0.0
    auto mm_mul_ops = symbolic::mul(symbolic::mul(res_elems, this->k_), alpha_non_zero);
    auto mm_sum_ops = symbolic::mul(symbolic::mul(res_elems, symbolic::sub(this->k_, symbolic::one())), alpha_non_zero);
    // conditional on alpha != 1.0 && alpha != 0.0
    auto mm_alpha_scale_ops = symbolic::mul(res_elems, symbolic::And(alpha_non_ident, alpha_non_zero));
    // conditional on beta != 1.0 && beta != 0.0
    auto mm_beta_scale_ops = symbolic::mul(res_elems, symbolic::And(beta_non_ident, beta_non_zero));
    // Adding alpha*A*B and beta*C only happens when both terms are present:
    // with alpha == 0 the result is just beta * C, with beta == 0 it is alpha * A * B.
    auto mm_beta_scaled_sum_ops = symbolic::mul(res_elems, symbolic::And(alpha_non_zero, beta_non_zero));
    auto mul_ops = symbolic::add(mm_mul_ops, symbolic::add(mm_alpha_scale_ops, mm_beta_scale_ops));
    auto add_ops = symbolic::add(mm_sum_ops, mm_beta_scaled_sum_ops);
    return symbolic::add(mul_ops, add_ops);
}
319

320
std::unique_ptr<data_flow::DataFlowNode> GEMMNode::
321
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
322
    auto node_clone = std::unique_ptr<GEMMNode>(new GEMMNode(
×
323
        element_id,
×
324
        this->debug_info(),
×
325
        vertex,
×
326
        parent,
×
327
        this->implementation_type_,
×
328
        this->precision_,
×
329
        this->layout_,
×
330
        this->trans_a_,
×
331
        this->trans_b_,
×
332
        this->m_,
×
333
        this->n_,
×
334
        this->k_,
×
335
        this->lda_,
×
336
        this->ldb_,
×
337
        this->ldc_
×
338
    ));
339
    return std::move(node_clone);
×
340
}
×
341

342
std::string GEMMNode::toStr() const {
×
343
    return LibraryNode::toStr() + "(" + static_cast<char>(precision_) + ", " +
×
344
           std::string(BLAS_Layout_to_short_string(layout_)) + ", " + BLAS_Transpose_to_char(trans_a_) +
×
345
           BLAS_Transpose_to_char(trans_b_) + ", " + m_->__str__() + ", " + n_->__str__() + ", " + k_->__str__() +
×
346
           ", " + lda_->__str__() + ", " + ldb_->__str__() + ", " + ldc_->__str__() + ")";
×
347
}
×
348

349
/// Serializes the GEMM-specific fields of `library_node` to JSON: the node
/// code, precision, layout, both transpose modes, and the symbolic m, n, k,
/// lda, ldb and ldc expressions.
///
/// `library_node` must be a GEMMNode; it is downcast without a check.
nlohmann::json GEMMNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const GEMMNode&>(library_node);
    serializer::JSONSerializer expr_writer;

    nlohmann::json out;
    out["code"] = node.code().value();

    // BLAS enumerations are stored as-is.
    out["precision"] = node.precision();
    out["layout"] = node.layout();
    out["trans_a"] = node.trans_a();
    out["trans_b"] = node.trans_b();

    // Symbolic sizes and leading dimensions.
    out["m"] = expr_writer.expression(node.m());
    out["n"] = expr_writer.expression(node.n());
    out["k"] = expr_writer.expression(node.k());
    out["lda"] = expr_writer.expression(node.lda());
    out["ldb"] = expr_writer.expression(node.ldb());
    out["ldc"] = expr_writer.expression(node.ldc());

    return out;
}
368

369
/// Rebuilds a GEMMNode from JSON and adds it to block `parent` through `builder`.
///
/// @param j       JSON object holding the node fields (see serialize) plus
///                "element_id", "debug_info" and "implementation_type".
/// @param builder Builder used to create the node.
/// @param parent  Block that receives the new node.
/// @return Reference to the newly added node.
/// @throws std::runtime_error if "code" does not name a GEMM node.
/// @throws nlohmann::json exceptions if a required field is missing or has the wrong type.
data_flow::LibraryNode& GEMMNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Fields every serialized library node must carry.
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // Refuse payloads that belong to a different library node type.
    const std::string node_code = j["code"].get<std::string>();
    if (node_code != LibraryNodeType_GEMM.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    sdfg::serializer::JSONSerializer reader;
    DebugInfo info = reader.json_to_debug_info(j["debug_info"]);

    // BLAS enumerations.
    const auto prec = j.at("precision").get<BLAS_Precision>();
    const auto lay = j.at("layout").get<BLAS_Layout>();
    const auto ta = j.at("trans_a").get<BLAS_Transpose>();
    const auto tb = j.at("trans_b").get<BLAS_Transpose>();

    // Symbolic sizes and leading dimensions.
    auto dim_m = symbolic::parse(j.at("m"));
    auto dim_n = symbolic::parse(j.at("n"));
    auto dim_k = symbolic::parse(j.at("k"));
    auto ld_a = symbolic::parse(j.at("lda"));
    auto ld_b = symbolic::parse(j.at("ldb"));
    auto ld_c = symbolic::parse(j.at("ldc"));

    auto impl = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<
        GEMMNode>(parent, info, impl, prec, lay, ta, tb, dim_m, dim_n, dim_k, ld_a, ld_b, ld_c);
}
402

403
/// Creates a dispatcher that lowers `node` to a CBLAS gemm call. All arguments
/// are forwarded unchanged to codegen::LibraryNodeDispatcher.
GEMMNodeDispatcher_BLAS::GEMMNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {
    // Nothing beyond base-class initialization.
}
410

411
void GEMMNodeDispatcher_BLAS::dispatch_code(
×
412
    codegen::PrettyPrinter& stream,
413
    codegen::PrettyPrinter& globals_stream,
414
    codegen::CodeSnippetFactory& library_snippet_factory
415
) {
416
    stream << "{" << std::endl;
×
417
    stream.setIndent(stream.indent() + 4);
×
418

419
    auto& gemm_node = static_cast<const GEMMNode&>(this->node_);
×
420

421
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
422
    switch (gemm_node.precision()) {
×
423
        case BLAS_Precision::h:
424
            base_type = types::Scalar(types::PrimitiveType::Half);
×
425
            break;
×
426
        case BLAS_Precision::s:
427
            base_type = types::Scalar(types::PrimitiveType::Float);
×
428
            break;
×
429
        case BLAS_Precision::d:
430
            base_type = types::Scalar(types::PrimitiveType::Double);
×
431
            break;
×
432
        default:
433
            throw std::runtime_error("Invalid BLAS_Precision value");
×
434
    }
435

436
    stream << "cblas_" << BLAS_Precision_to_string(gemm_node.precision()) << "gemm(";
×
437
    stream.setIndent(stream.indent() + 4);
×
438
    stream << BLAS_Layout_to_string(gemm_node.layout());
×
439
    stream << ", ";
×
440
    stream << BLAS_Transpose_to_string(gemm_node.trans_a());
×
441
    stream << ", ";
×
442
    stream << BLAS_Transpose_to_string(gemm_node.trans_b());
×
443
    stream << ", ";
×
444
    stream << this->language_extension_.expression(gemm_node.m());
×
445
    stream << ", ";
×
446
    stream << this->language_extension_.expression(gemm_node.n());
×
447
    stream << ", ";
×
448
    stream << this->language_extension_.expression(gemm_node.k());
×
449
    stream << ", ";
×
450
    stream << "alpha";
×
451
    stream << ", ";
×
452
    stream << "A";
×
453
    stream << ", ";
×
454
    stream << this->language_extension_.expression(gemm_node.lda());
×
455
    stream << ", ";
×
456
    stream << "B";
×
457
    stream << ", ";
×
458
    stream << this->language_extension_.expression(gemm_node.ldb());
×
459
    stream << ", ";
×
460
    stream << "beta";
×
461
    stream << ", ";
×
462
    stream << "C";
×
463
    stream << ", ";
×
464
    stream << this->language_extension_.expression(gemm_node.ldc());
×
465

466
    stream.setIndent(stream.indent() - 4);
×
467
    stream << ");" << std::endl;
×
468

469
    stream.setIndent(stream.indent() - 4);
×
470
    stream << "}" << std::endl;
×
471
}
×
472

473
/// Creates the cuBLAS (with host/device transfers) dispatcher for `node`.
/// All arguments are forwarded unchanged to codegen::LibraryNodeDispatcher.
GEMMNodeDispatcher_CUBLASWithTransfers::GEMMNodeDispatcher_CUBLASWithTransfers(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {
    // Nothing beyond base-class initialization.
}
480

481
void GEMMNodeDispatcher_CUBLASWithTransfers::dispatch_code(
×
482
    codegen::PrettyPrinter& stream,
483
    codegen::PrettyPrinter& globals_stream,
484
    codegen::CodeSnippetFactory& library_snippet_factory
485
) {
486
    throw std::runtime_error("GEMMNodeDispatcher_CUBLAS not implemented");
×
487
}
×
488

489
/// Creates the cuBLAS (without host/device transfers) dispatcher for `node`.
/// All arguments are forwarded unchanged to codegen::LibraryNodeDispatcher.
GEMMNodeDispatcher_CUBLASWithoutTransfers::GEMMNodeDispatcher_CUBLASWithoutTransfers(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {
    // Nothing beyond base-class initialization.
}
496

497
void GEMMNodeDispatcher_CUBLASWithoutTransfers::dispatch_code(
×
498
    codegen::PrettyPrinter& stream,
499
    codegen::PrettyPrinter& globals_stream,
500
    codegen::CodeSnippetFactory& library_snippet_factory
501
) {
502
    throw std::runtime_error("GEMMNodeDispatcher_CUBLAS not implemented");
×
503
}
×
504

505
} // namespace blas
506
} // namespace math
507
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc