• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 16468861557

23 Jul 2025 11:03AM UTC coverage: 64.967% (-1.0%) from 66.011%
16468861557

Pull #156

github

web-flow
Merge 054d54480 into 4c085404b
Pull Request #156: adds draft for GEMM node

29 of 272 new or added lines in 14 files covered. (10.66%)

4 existing lines in 2 files now uncovered.

8360 of 12868 relevant lines covered (64.97%)

131.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/blas/gemm.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/gemm.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace blas {
11

NEW
12
/// Constructs a GEMM library node with output connector "C" and input connectors "A", "B", "C".
///
/// Precision, layout, transposes, the symbolic sizes (m, n, k) and leading dimensions
/// (lda, ldb, ldc) are stored as given. `alpha` and `beta` are scalar literals: passing an
/// empty string for either one means the value is not a literal and must instead arrive on an
/// extra input connector of the same name. Such connectors are appended after A, B, C, with
/// "alpha" before "beta".
GEMMNode::GEMMNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const data_flow::ImplementationType& implementation_type,
    const BLAS_Precision& precision,
    const BLAS_Layout& layout,
    const BLAS_Transpose& trans_a,
    const BLAS_Transpose& trans_b,
    symbolic::Expression m,
    symbolic::Expression n,
    symbolic::Expression k,
    symbolic::Expression lda,
    symbolic::Expression ldb,
    symbolic::Expression ldc,
    const std::string& alpha,
    const std::string& beta
)
    : MathNode(
          element_id, debug_info, vertex, parent, LibraryNodeType_GEMM, {"C"}, {"A", "B", "C"}, implementation_type
      ),
      precision_(precision),
      layout_(layout),
      trans_a_(trans_a),
      trans_b_(trans_b),
      m_(m),
      n_(n),
      k_(k),
      lda_(lda),
      ldb_(ldb),
      ldc_(ldc),
      alpha_(alpha),
      beta_(beta) {
    // An empty literal signals "provided at runtime through a connector".
    const bool alpha_from_connector = alpha.empty();
    const bool beta_from_connector = beta.empty();

    if (alpha_from_connector) {
        this->inputs_.push_back("alpha");
    }
    if (beta_from_connector) {
        this->inputs_.push_back("beta");
    }
}
×
41

NEW
42
BLAS_Precision GEMMNode::precision() const { return this->precision_; };
×
43

NEW
44
BLAS_Layout GEMMNode::layout() const { return this->layout_; };
×
45

NEW
46
BLAS_Transpose GEMMNode::trans_a() const { return this->trans_a_; };
×
47

NEW
48
BLAS_Transpose GEMMNode::trans_b() const { return this->trans_b_; };
×
49

NEW
50
symbolic::Expression GEMMNode::m() const { return this->m_; };
×
51

NEW
52
symbolic::Expression GEMMNode::n() const { return this->n_; };
×
53

NEW
54
symbolic::Expression GEMMNode::k() const { return this->k_; };
×
55

NEW
56
symbolic::Expression GEMMNode::lda() const { return this->lda_; };
×
57

NEW
58
symbolic::Expression GEMMNode::ldb() const { return this->ldb_; };
×
59

NEW
60
symbolic::Expression GEMMNode::ldc() const { return this->ldc_; };
×
61

NEW
62
std::string GEMMNode::alpha() const { return this->alpha_; };
×
63

NEW
64
std::string GEMMNode::beta() const { return this->beta_; };
×
65

NEW
66
void GEMMNode::validate(const Function& function) const {
×
NEW
67
    auto& graph = this->get_parent();
×
68

NEW
69
    if (graph.in_degree(*this) != this->inputs_.size()) {
×
NEW
70
        throw InvalidSDFGException("GEMMNode must have " + std::to_string(this->inputs_.size()) + " inputs");
×
71
    }
NEW
72
    if (graph.out_degree(*this) != 1) {
×
NEW
73
        throw InvalidSDFGException("GEMMNode must have 1 output");
×
74
    }
75

76
    // Check if all inputs are connected A, B, C, (alpha), (beta)
NEW
77
    std::unordered_map<std::string, const data_flow::Memlet*> memlets;
×
NEW
78
    for (auto& input : this->inputs_) {
×
NEW
79
        bool found = false;
×
NEW
80
        for (auto& iedge : graph.in_edges(*this)) {
×
NEW
81
            if (iedge.dst_conn() == input) {
×
NEW
82
                found = true;
×
NEW
83
                memlets[input] = &iedge;
×
NEW
84
                break;
×
85
            }
86
        }
NEW
87
        if (!found) {
×
NEW
88
            throw InvalidSDFGException("GEMMNode input " + input + " not found");
×
89
        }
90
    }
91

92
    // Check if output is connected to C
NEW
93
    auto& oedge = *graph.out_edges(*this).begin();
×
NEW
94
    if (oedge.src_conn() != this->outputs_.at(0)) {
×
NEW
95
        throw InvalidSDFGException("GEMMNode output " + this->outputs_.at(0) + " not found");
×
96
    }
97

98
    // Check dimensions of A, B, C
NEW
99
    auto& a_memlet = memlets.at("A");
×
NEW
100
    auto& a_subset_begin = a_memlet->begin_subset();
×
NEW
101
    auto& a_subset_end = a_memlet->end_subset();
×
NEW
102
    if (a_subset_begin.size() != 1) {
×
NEW
103
        throw InvalidSDFGException("GEMMNode input A must have 1 dimensions");
×
104
    }
NEW
105
    data_flow::Subset a_dims;
×
NEW
106
    for (size_t i = 0; i < a_subset_begin.size(); i++) {
×
NEW
107
        a_dims.push_back(symbolic::sub(a_subset_end[i], a_subset_begin[i]));
×
NEW
108
    }
×
109

NEW
110
    auto& b_memlet = memlets.at("B");
×
NEW
111
    auto& b_subset_begin = b_memlet->begin_subset();
×
NEW
112
    auto& b_subset_end = b_memlet->end_subset();
×
NEW
113
    if (b_subset_begin.size() != 1) {
×
NEW
114
        throw InvalidSDFGException("GEMMNode input B must have 1 dimensions");
×
115
    }
NEW
116
    data_flow::Subset b_dims;
×
NEW
117
    for (size_t i = 0; i < b_subset_begin.size(); i++) {
×
NEW
118
        b_dims.push_back(symbolic::sub(b_subset_end[i], b_subset_begin[i]));
×
NEW
119
    }
×
120

NEW
121
    auto& c_memlet = memlets.at("C");
×
NEW
122
    auto& c_subset_begin = c_memlet->begin_subset();
×
NEW
123
    auto& c_subset_end = c_memlet->end_subset();
×
NEW
124
    if (c_subset_begin.size() != 1) {
×
NEW
125
        throw InvalidSDFGException("GEMMNode input C must have 1 dimensions");
×
126
    }
NEW
127
    data_flow::Subset c_dims;
×
NEW
128
    for (size_t i = 0; i < c_subset_begin.size(); i++) {
×
NEW
129
        c_dims.push_back(symbolic::sub(c_subset_end[i], c_subset_begin[i]));
×
NEW
130
    }
×
131

132
    // TODO: Check if dimensions of A, B, C are valid
NEW
133
}
×
134

NEW
135
/// Expands this GEMM node into more primitive SDFG constructs.
///
/// Not implemented yet: the graph is left untouched and `false` is returned, so callers
/// keep the library node as-is (it can still be lowered by a dispatcher).
///
/// @param builder          Builder for the enclosing structured SDFG (unused until implemented).
/// @param analysis_manager Analysis cache (unused until implemented).
/// @return Always `false` (no expansion performed).
bool GEMMNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // TODO: Expand GEMM node. The enclosing block is reachable via
    // this->get_parent().get_parent(); builder.subject() gives the SDFG being edited.
    return false;
}
144

145
std::unique_ptr<data_flow::DataFlowNode> GEMMNode::
NEW
146
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
147
    auto node_clone = std::unique_ptr<GEMMNode>(new GEMMNode(
×
NEW
148
        element_id,
×
NEW
149
        this->debug_info(),
×
NEW
150
        vertex,
×
NEW
151
        parent,
×
NEW
152
        this->implementation_type_,
×
NEW
153
        this->precision_,
×
NEW
154
        this->layout_,
×
NEW
155
        this->trans_a_,
×
NEW
156
        this->trans_b_,
×
NEW
157
        this->m_,
×
NEW
158
        this->n_,
×
NEW
159
        this->k_,
×
NEW
160
        this->lda_,
×
NEW
161
        this->ldb_,
×
NEW
162
        this->ldc_,
×
NEW
163
        this->alpha_,
×
NEW
164
        this->beta_
×
165
    ));
NEW
166
    return std::move(node_clone);
×
NEW
167
}
×
168

NEW
169
/// Serializes the GEMM-specific fields of `library_node` to JSON.
///
/// Writes "code", the BLAS enums ("precision", "layout", "trans_a", "trans_b"), the sizes
/// "m"/"n"/"k" and leading dimensions "lda"/"ldb"/"ldc" (via JSONSerializer::expression),
/// and the "alpha"/"beta" literals.
/// NOTE(review): deserialize() also reads "element_id", "debug_info" and
/// "implementation_type", which are not written here — presumably the generic JSON
/// serializer adds them around this object; confirm.
///
/// @throws std::runtime_error if `library_node` is not a GEMMNode.
nlohmann::json GEMMNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    // Checked downcast: a blind static_cast of a node of another type would be undefined behavior.
    const auto* gemm_ptr = dynamic_cast<const GEMMNode*>(&library_node);
    if (gemm_ptr == nullptr) {
        throw std::runtime_error("GEMMNodeSerializer can only serialize GEMMNode");
    }
    const GEMMNode& gemm_node = *gemm_ptr;

    nlohmann::json j;

    serializer::JSONSerializer serializer;
    j["code"] = gemm_node.code().value();
    j["precision"] = gemm_node.precision();
    j["layout"] = gemm_node.layout();
    j["trans_a"] = gemm_node.trans_a();
    j["trans_b"] = gemm_node.trans_b();
    j["m"] = serializer.expression(gemm_node.m());
    j["n"] = serializer.expression(gemm_node.n());
    j["k"] = serializer.expression(gemm_node.k());
    j["lda"] = serializer.expression(gemm_node.lda());
    j["ldb"] = serializer.expression(gemm_node.ldb());
    j["ldc"] = serializer.expression(gemm_node.ldc());
    j["alpha"] = gemm_node.alpha();
    j["beta"] = gemm_node.beta();

    return j;
}
×
190

NEW
191
/// Rebuilds a GEMMNode inside `parent` from the JSON written for it.
///
/// Reads "code" (must equal the GEMM library node code), "debug_info", the BLAS enums
/// ("precision", "layout", "trans_a", "trans_b"), the symbolic sizes/leading dimensions
/// ("m", "n", "k", "lda", "ldb", "ldc"), the "alpha"/"beta" literals and
/// "implementation_type". "element_id" is only asserted present; the builder assigns the
/// new node's id.
///
/// @return The newly added library node.
/// @throws std::runtime_error if "code" is not the GEMM code.
/// @throws nlohmann::json::out_of_range if a required field is missing.
data_flow::LibraryNode& GEMMNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // `j` is const: const operator[] on a missing key is undefined behavior once the asserts
    // above are compiled out (NDEBUG). at() throws json::out_of_range instead.
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_GEMM.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto layout = j.at("layout").get<BLAS_Layout>();
    auto trans_a = j.at("trans_a").get<BLAS_Transpose>();
    auto trans_b = j.at("trans_b").get<BLAS_Transpose>();
    auto m = SymEngine::Expression(j.at("m"));
    auto n = SymEngine::Expression(j.at("n"));
    auto k = SymEngine::Expression(j.at("k"));
    auto lda = SymEngine::Expression(j.at("lda"));
    auto ldb = SymEngine::Expression(j.at("ldb"));
    auto ldc = SymEngine::Expression(j.at("ldc"));
    auto alpha = j.at("alpha").get<std::string>();
    auto beta = j.at("beta").get<std::string>();

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<GEMMNode>(
        parent, debug_info, implementation_type, precision, layout, trans_a, trans_b, m, n, k, lda, ldb, ldc, alpha, beta
    );
}
×
227

NEW
228
/// Creates the CBLAS code generator for `node`.
///
/// Performs no work of its own; every argument is forwarded unchanged to the
/// codegen::LibraryNodeDispatcher base.
GEMMNodeDispatcher_BLAS::GEMMNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {
    // Nothing to initialize beyond the base dispatcher.
}
×
235

NEW
236
/// Emits a scoped CBLAS call implementing this GEMM node.
///
/// Generated shape:
///   {
///       <decl of connector> = <source container>;   // one per incoming memlet (A, B, C, ...)
///       <scalar> alpha = <literal>;                   // only if alpha is not an input connector
///       <scalar> beta = <literal>;                    // only if beta is not an input connector
///       cblas_<p>gemm(layout, trans_a, trans_b, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
///   }
///
/// Every incoming memlet's source is assumed to be an AccessNode (it is static_cast to one).
///
/// @throws std::runtime_error if the node's precision has no mapped scalar type.
void GEMMNodeDispatcher_BLAS::dispatch(codegen::PrettyPrinter& stream) {
    stream << "{" << std::endl;
    stream.setIndent(stream.indent() + 4);

    auto& gemm_node = static_cast<const GEMMNode&>(this->node_);

    // Scalar type used to declare alpha/beta when they are literals.
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
    switch (gemm_node.precision()) {
        case BLAS_Precision::h:
            // NOTE(review): reference CBLAS has no cblas_hgemm; this only links against
            // libraries that provide it (e.g. oneMKL). Confirm the intended backend.
            base_type = types::Scalar(types::PrimitiveType::Half);
            break;
        case BLAS_Precision::s:
            base_type = types::Scalar(types::PrimitiveType::Float);
            break;
        case BLAS_Precision::d:
            base_type = types::Scalar(types::PrimitiveType::Double);
            break;
        default:
            throw std::runtime_error("Invalid BLAS_Precision value");
    }

    // Bind each incoming container to a local named after its destination connector.
    auto& graph = this->node_.get_parent();
    for (auto& iedge : graph.in_edges(this->node_)) {
        auto& access_node = static_cast<const data_flow::AccessNode&>(iedge.src());
        std::string name = access_node.data();
        auto& type = this->function_.type(name);

        stream << this->language_extension_.declaration(iedge.dst_conn(), type);
        stream << " = " << name << ";" << std::endl;
    }

    // Fetch the connector list once: calling inputs() separately for begin() and end()
    // would pair iterators of two different containers if it ever returned by value.
    const auto& inputs = gemm_node.inputs();
    const auto is_input = [&inputs](const char* connector) {
        return std::find(inputs.begin(), inputs.end(), connector) != inputs.end();
    };

    // Literal alpha/beta become locals so the call below can always refer to them by name.
    if (!is_input("alpha")) {
        stream << this->language_extension_.declaration("alpha", base_type);
        stream << " = " << gemm_node.alpha() << ";" << std::endl;
    }
    if (!is_input("beta")) {
        stream << this->language_extension_.declaration("beta", base_type);
        stream << " = " << gemm_node.beta() << ";" << std::endl;
    }

    // Argument order follows the CBLAS ?gemm signature.
    stream << "cblas_" << BLAS_Precision_to_string(gemm_node.precision()) << "gemm(";
    stream.setIndent(stream.indent() + 4);
    stream << BLAS_Layout_to_string(gemm_node.layout()) << ", ";
    stream << BLAS_Transpose_to_string(gemm_node.trans_a()) << ", ";
    stream << BLAS_Transpose_to_string(gemm_node.trans_b()) << ", ";
    stream << this->language_extension_.expression(gemm_node.m()) << ", ";
    stream << this->language_extension_.expression(gemm_node.n()) << ", ";
    stream << this->language_extension_.expression(gemm_node.k()) << ", ";
    stream << "alpha" << ", ";
    stream << "A" << ", ";
    stream << this->language_extension_.expression(gemm_node.lda()) << ", ";
    stream << "B" << ", ";
    stream << this->language_extension_.expression(gemm_node.ldb()) << ", ";
    stream << "beta" << ", ";
    stream << "C" << ", ";
    stream << this->language_extension_.expression(gemm_node.ldc());
    stream.setIndent(stream.indent() - 4);
    stream << ");" << std::endl;

    stream.setIndent(stream.indent() - 4);
    stream << "}" << std::endl;
}
×
312

NEW
313
/// Creates the cuBLAS code generator for `node`.
///
/// All arguments are forwarded unchanged to the codegen::LibraryNodeDispatcher base;
/// code emission itself is not implemented yet (see dispatch()).
GEMMNodeDispatcher_CUBLAS::GEMMNodeDispatcher_CUBLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {
    // Nothing to initialize beyond the base dispatcher.
}
×
320

NEW
321
void GEMMNodeDispatcher_CUBLAS::dispatch(codegen::PrettyPrinter& stream) {
×
NEW
322
    throw std::runtime_error("GEMMNodeDispatcher_CUBLAS not implemented");
×
NEW
323
}
×
324

325
} // namespace blas
326
} // namespace math
327
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc