• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 16454790608

22 Jul 2025 08:26PM UTC coverage: 65.244% (-0.8%) from 66.011%
16454790608

Pull #156

github

web-flow
Merge 8b3fea29a into 4c085404b
Pull Request #156: adds draft for GEMM node

10 of 165 new or added lines in 5 files covered. (6.06%)

3 existing lines in 1 file now uncovered.

8331 of 12769 relevant lines covered (65.24%)

132.3 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/blas/gemm.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/gemm.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace blas {
11

NEW
12
/**
 * Constructs a GEMM library node.
 *
 * The node is registered with output connector "C" and input connectors
 * "A", "B", "C". For each of `alpha` and `beta` that is passed as an empty
 * string, an extra input connector with that name is appended (alpha first,
 * then beta); validate() then requires an incoming memlet on it.
 *
 * @param precision, layout, trans_a, trans_b  BLAS configuration, stored as-is.
 * @param m, n, k, lda, ldb, ldc               Symbolic sizes / leading dimensions, stored as-is.
 * @param alpha, beta                          Scalar factors as strings; empty means "provided via connector".
 */
GEMMNode::GEMMNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const BLAS_Precision& precision,
    const BLAS_Layout& layout,
    const BLAS_Transpose& trans_a,
    const BLAS_Transpose& trans_b,
    symbolic::Expression m,
    symbolic::Expression n,
    symbolic::Expression k,
    symbolic::Expression lda,
    symbolic::Expression ldb,
    symbolic::Expression ldc,
    const std::string& alpha,
    const std::string& beta
)
    : MathNode(element_id, debug_info, vertex, parent, LibraryNodeType_GEMM, {"C"}, {"A", "B", "C"}),
      precision_(precision),
      layout_(layout),
      trans_a_(trans_a),
      trans_b_(trans_b),
      m_(m),
      n_(n),
      k_(k),
      lda_(lda),
      ldb_(ldb),
      ldc_(ldc),
      alpha_(alpha),
      beta_(beta) {
    // A scalar without a fixed value becomes an additional data input.
    auto expose_if_unset = [this](const std::string& value, const char* connector) {
        if (!value.empty()) {
            return;
        }
        this->inputs_.push_back(connector);
    };
    expose_if_unset(alpha, "alpha");
    expose_if_unset(beta, "beta");
}
×
40

NEW
41
BLAS_Precision GEMMNode::precision() const { return this->precision_; };
×
42

NEW
43
BLAS_Layout GEMMNode::layout() const { return this->layout_; };
×
44

NEW
45
BLAS_Transpose GEMMNode::trans_a() const { return this->trans_a_; };
×
46

NEW
47
BLAS_Transpose GEMMNode::trans_b() const { return this->trans_b_; };
×
48

NEW
49
symbolic::Expression GEMMNode::m() const { return this->m_; };
×
50

NEW
51
symbolic::Expression GEMMNode::n() const { return this->n_; };
×
52

NEW
53
symbolic::Expression GEMMNode::k() const { return this->k_; };
×
54

NEW
55
symbolic::Expression GEMMNode::lda() const { return this->lda_; };
×
56

NEW
57
symbolic::Expression GEMMNode::ldb() const { return this->ldb_; };
×
58

NEW
59
symbolic::Expression GEMMNode::ldc() const { return this->ldc_; };
×
60

NEW
61
std::string GEMMNode::alpha() const { return this->alpha_; };
×
62

NEW
63
std::string GEMMNode::beta() const { return this->beta_; };
×
64

NEW
65
void GEMMNode::validate(const Function& function) const {
×
NEW
66
    auto& graph = this->get_parent();
×
67

NEW
68
    if (graph.in_degree(*this) != this->inputs_.size()) {
×
NEW
69
        throw InvalidSDFGException("GEMMNode must have " + std::to_string(this->inputs_.size()) + " inputs");
×
70
    }
NEW
71
    if (graph.out_degree(*this) != 1) {
×
NEW
72
        throw InvalidSDFGException("GEMMNode must have 1 output");
×
73
    }
74

75
    // Check if all inputs are connected A, B, C, (alpha), (beta)
NEW
76
    std::unordered_map<std::string, const data_flow::Memlet*> memlets;
×
NEW
77
    for (auto& input : this->inputs_) {
×
NEW
78
        bool found = false;
×
NEW
79
        for (auto& iedge : graph.in_edges(*this)) {
×
NEW
80
            if (iedge.dst_conn() == input) {
×
NEW
81
                found = true;
×
NEW
82
                memlets[input] = &iedge;
×
NEW
83
                break;
×
84
            }
85
        }
NEW
86
        if (!found) {
×
NEW
87
            throw InvalidSDFGException("GEMMNode input " + input + " not found");
×
88
        }
89
    }
90

91
    // Check if output is connected to C
NEW
92
    auto& oedge = *graph.out_edges(*this).begin();
×
NEW
93
    if (oedge.src_conn() != this->outputs_.at(0)) {
×
NEW
94
        throw InvalidSDFGException("GEMMNode output " + this->outputs_.at(0) + " not found");
×
95
    }
96

97
    // Check dimensions of A, B, C
NEW
98
    auto& a_memlet = memlets.at("A");
×
NEW
99
    auto& a_subset_begin = a_memlet->begin_subset();
×
NEW
100
    auto& a_subset_end = a_memlet->end_subset();
×
NEW
101
    if (a_subset_begin.size() != 1) {
×
NEW
102
        throw InvalidSDFGException("GEMMNode input A must have 1 dimensions");
×
103
    }
NEW
104
    data_flow::Subset a_dims;
×
NEW
105
    for (size_t i = 0; i < a_subset_begin.size(); i++) {
×
NEW
106
        a_dims.push_back(symbolic::sub(a_subset_end[i], a_subset_begin[i]));
×
NEW
107
    }
×
108

NEW
109
    auto& b_memlet = memlets.at("B");
×
NEW
110
    auto& b_subset_begin = b_memlet->begin_subset();
×
NEW
111
    auto& b_subset_end = b_memlet->end_subset();
×
NEW
112
    if (b_subset_begin.size() != 1) {
×
NEW
113
        throw InvalidSDFGException("GEMMNode input B must have 1 dimensions");
×
114
    }
NEW
115
    data_flow::Subset b_dims;
×
NEW
116
    for (size_t i = 0; i < b_subset_begin.size(); i++) {
×
NEW
117
        b_dims.push_back(symbolic::sub(b_subset_end[i], b_subset_begin[i]));
×
NEW
118
    }
×
119

NEW
120
    auto& c_memlet = memlets.at("C");
×
NEW
121
    auto& c_subset_begin = c_memlet->begin_subset();
×
NEW
122
    auto& c_subset_end = c_memlet->end_subset();
×
NEW
123
    if (c_subset_begin.size() != 1) {
×
NEW
124
        throw InvalidSDFGException("GEMMNode input C must have 1 dimensions");
×
125
    }
NEW
126
    data_flow::Subset c_dims;
×
NEW
127
    for (size_t i = 0; i < c_subset_begin.size(); i++) {
×
NEW
128
        c_dims.push_back(symbolic::sub(c_subset_end[i], c_subset_begin[i]));
×
NEW
129
    }
×
130

131
    // TODO: Check if dimensions of A, B, C are valid
NEW
132
}
×
133

NEW
134
/**
 * Expands this GEMM node into lower-level SDFG constructs.
 *
 * Not implemented yet: the graph is left untouched and `false` is returned
 * (presumably meaning "no expansion performed" — confirm against the
 * LibraryNode::expand contract).
 *
 * @param builder           Builder for the enclosing structured SDFG (unused for now).
 * @param analysis_manager  Analysis cache (unused for now).
 * @return Always false.
 */
bool GEMMNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // TODO: Expand GEMM node. An implementation will likely need the SDFG
    // (`builder.subject()`) and the enclosing block, reachable as
    // `static_cast<structured_control_flow::Block&>(*this->get_parent().get_parent())`.
    return false;
}
143

144
std::unique_ptr<data_flow::DataFlowNode> GEMMNode::
NEW
145
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
146
    return std::unique_ptr<data_flow::DataFlowNode>(new GEMMNode(
×
NEW
147
        element_id,
×
NEW
148
        this->debug_info(),
×
NEW
149
        vertex,
×
NEW
150
        parent,
×
NEW
151
        this->precision_,
×
NEW
152
        this->layout_,
×
NEW
153
        this->trans_a_,
×
NEW
154
        this->trans_b_,
×
NEW
155
        this->m_,
×
NEW
156
        this->n_,
×
NEW
157
        this->k_,
×
NEW
158
        this->lda_,
×
NEW
159
        this->ldb_,
×
NEW
160
        this->ldc_,
×
NEW
161
        this->alpha_,
×
NEW
162
        this->beta_
×
163
    ));
NEW
164
}
×
165

NEW
166
/**
 * Serializes a GEMM library node into a JSON object.
 *
 * Written keys: "code", "precision", "layout", "trans_a", "trans_b",
 * "m", "n", "k", "lda", "ldb", "ldc" (via JSONSerializer::expression),
 * "alpha", "beta". "element_id" and "debug_info" are not written here.
 *
 * @param library_node Node to serialize; must be a GEMMNode (unchecked static_cast).
 * @return The JSON representation.
 */
nlohmann::json GEMMNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const GEMMNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    nlohmann::json out;
    out["code"] = node.code().value();

    // BLAS call configuration.
    out["precision"] = node.precision();
    out["layout"] = node.layout();
    out["trans_a"] = node.trans_a();
    out["trans_b"] = node.trans_b();

    // Symbolic problem sizes.
    out["m"] = expr_serializer.expression(node.m());
    out["n"] = expr_serializer.expression(node.n());
    out["k"] = expr_serializer.expression(node.k());

    // Symbolic leading dimensions.
    out["lda"] = expr_serializer.expression(node.lda());
    out["ldb"] = expr_serializer.expression(node.ldb());
    out["ldc"] = expr_serializer.expression(node.ldc());

    // Scalar factors (empty when supplied through the matching input connector).
    out["alpha"] = node.alpha();
    out["beta"] = node.beta();

    return out;
}
×
187

NEW
188
/**
 * Reconstructs a GEMM library node from JSON and adds it to `parent`.
 *
 * Expects the keys written by serialize() plus "element_id" and "debug_info".
 * "element_id" is only checked for presence; it is not passed to the builder.
 *
 * @param j       JSON object describing the node.
 * @param builder Builder used to create the node.
 * @param parent  Block the new node is added to.
 * @return Reference to the newly created node.
 * @throws std::runtime_error if "code" does not name a GEMM node.
 * @throws nlohmann::json::exception if a required key is missing or has the wrong type.
 */
data_flow::LibraryNode& GEMMNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields (debug builds only).
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // Use j.at(...) throughout: const operator[] on a missing key is undefined
    // behavior once the asserts above are compiled out, whereas at() throws.
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_GEMM.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto layout = j.at("layout").get<BLAS_Layout>();
    auto trans_a = j.at("trans_a").get<BLAS_Transpose>();
    auto trans_b = j.at("trans_b").get<BLAS_Transpose>();
    auto m = SymEngine::Expression(j.at("m"));
    auto n = SymEngine::Expression(j.at("n"));
    auto k = SymEngine::Expression(j.at("k"));
    auto lda = SymEngine::Expression(j.at("lda"));
    auto ldb = SymEngine::Expression(j.at("ldb"));
    auto ldc = SymEngine::Expression(j.at("ldc"));
    auto alpha = j.at("alpha").get<std::string>();
    auto beta = j.at("beta").get<std::string>();

    return builder.add_library_node<
        GEMMNode>(parent, debug_info, precision, layout, trans_a, trans_b, m, n, k, lda, ldb, ldc, alpha, beta);
}
×
221

NEW
222
GEMMNodeDispatcher_BLAS::GEMMNodeDispatcher_BLAS(
×
223
    codegen::LanguageExtension& language_extension,
224
    const Function& function,
225
    const data_flow::DataFlowGraph& data_flow_graph,
226
    const GEMMNode& node
227
)
NEW
228
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {}
×
229

NEW
230
void GEMMNodeDispatcher_BLAS::dispatch(codegen::PrettyPrinter& stream) {
×
NEW
231
    throw std::runtime_error("GEMMNodeDispatcher_BLAS not implemented");
×
NEW
232
}
×
233

NEW
234
GEMMNodeDispatcher_CUBLAS::GEMMNodeDispatcher_CUBLAS(
×
235
    codegen::LanguageExtension& language_extension,
236
    const Function& function,
237
    const data_flow::DataFlowGraph& data_flow_graph,
238
    const GEMMNode& node
239
)
NEW
240
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {}
×
241

NEW
242
void GEMMNodeDispatcher_CUBLAS::dispatch(codegen::PrettyPrinter& stream) {
×
NEW
243
    throw std::runtime_error("GEMMNodeDispatcher_CUBLAS not implemented");
×
NEW
244
}
×
245

246
} // namespace blas
247
} // namespace math
248
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc