• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 16454139884

22 Jul 2025 07:52PM UTC coverage: 65.338% (-0.7%) from 66.011%
16454139884

Pull #156

github

web-flow
Merge cf357586f into 4c085404b
Pull Request #156: adds draft for GEMM node

0 of 130 new or added lines in 1 file covered. (0.0%)

110 existing lines in 10 files now uncovered.

8326 of 12743 relevant lines covered (65.34%)

132.57 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/blas/gemm.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/gemm.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace blas {
11

NEW
12
/// Constructs a GEMM library node (C = alpha * op(A) * op(B) + beta * C).
///
/// The node always exposes the input connectors A, B, C and the output connector C.
/// When `alpha` or `beta` is passed as an empty string, an additional input
/// connector with that name is appended (alpha first, then beta), presumably so the
/// scalar can be supplied through the data-flow graph instead — TODO confirm.
///
/// @param element_id  Unique id of this element.
/// @param debug_info  Source-location information.
/// @param vertex      Graph vertex backing this node.
/// @param parent      Data-flow graph that owns this node.
/// @param precision   BLAS precision of the operation.
/// @param layout      BLAS memory layout of the operands.
/// @param trans_a     Transpose mode for A.
/// @param trans_b     Transpose mode for B.
/// @param m, n, k     Symbolic problem sizes.
/// @param lda, ldb, ldc Symbolic leading dimensions of A, B, C.
/// @param alpha       Scalar alpha, or empty to request an "alpha" input connector.
/// @param beta        Scalar beta, or empty to request a "beta" input connector.
GEMMNode::GEMMNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const BLAS_Precision& precision,
    const BLAS_Layout& layout,
    const BLAS_Transpose& trans_a,
    const BLAS_Transpose& trans_b,
    symbolic::Expression m,
    symbolic::Expression n,
    symbolic::Expression k,
    symbolic::Expression lda,
    symbolic::Expression ldb,
    symbolic::Expression ldc,
    const std::string& alpha,
    const std::string& beta
)
    : MathNode(element_id, debug_info, vertex, parent, LibraryNodeType_GEMM, {"C"}, {"A", "B", "C"}),
      precision_(precision),
      layout_(layout),
      trans_a_(trans_a),
      trans_b_(trans_b),
      m_(m),
      n_(n),
      k_(k),
      lda_(lda),
      ldb_(ldb),
      ldc_(ldc),
      alpha_(alpha),
      beta_(beta) {
    // Optional scalar connectors follow A, B, C; the order (alpha, then beta) matters
    // because validate() and clone() rely on inputs_ being derived the same way.
    if (alpha.empty()) this->inputs_.emplace_back("alpha");
    if (beta.empty()) this->inputs_.emplace_back("beta");
}
×
40

NEW
41
void GEMMNode::validate(const Function& function) const {
×
NEW
42
    auto& graph = this->get_parent();
×
43

NEW
44
    if (graph.in_degree(*this) != this->inputs_.size()) {
×
NEW
45
        throw InvalidSDFGException("GEMMNode must have " + std::to_string(this->inputs_.size()) + " inputs");
×
46
    }
NEW
47
    if (graph.out_degree(*this) != 1) {
×
NEW
48
        throw InvalidSDFGException("GEMMNode must have 1 output");
×
49
    }
50

51
    // Check if all inputs are connected A, B, C, (alpha), (beta)
NEW
52
    std::unordered_map<std::string, const data_flow::Memlet*> memlets;
×
NEW
53
    for (auto& input : this->inputs_) {
×
NEW
54
        bool found = false;
×
NEW
55
        for (auto& iedge : graph.in_edges(*this)) {
×
NEW
56
            if (iedge.dst_conn() == input) {
×
NEW
57
                found = true;
×
NEW
58
                memlets[input] = &iedge;
×
NEW
59
                break;
×
60
            }
61
        }
NEW
62
        if (!found) {
×
NEW
63
            throw InvalidSDFGException("GEMMNode input " + input + " not found");
×
64
        }
65
    }
66

67
    // Check if output is connected to C
NEW
68
    auto& oedge = *graph.out_edges(*this).begin();
×
NEW
69
    if (oedge.src_conn() != this->outputs_.at(0)) {
×
NEW
70
        throw InvalidSDFGException("GEMMNode output " + this->outputs_.at(0) + " not found");
×
71
    }
72

73
    // Check dimensions of A, B, C
NEW
74
    auto& a_memlet = memlets.at("A");
×
NEW
75
    auto& a_subset_begin = a_memlet->begin_subset();
×
NEW
76
    auto& a_subset_end = a_memlet->end_subset();
×
NEW
77
    if (a_subset_begin.size() != 1) {
×
NEW
78
        throw InvalidSDFGException("GEMMNode input A must have 1 dimensions");
×
79
    }
NEW
80
    data_flow::Subset a_dims;
×
NEW
81
    for (size_t i = 0; i < a_subset_begin.size(); i++) {
×
NEW
82
        a_dims.push_back(symbolic::sub(a_subset_end[i], a_subset_begin[i]));
×
NEW
83
    }
×
84

NEW
85
    auto& b_memlet = memlets.at("B");
×
NEW
86
    auto& b_subset_begin = b_memlet->begin_subset();
×
NEW
87
    auto& b_subset_end = b_memlet->end_subset();
×
NEW
88
    if (b_subset_begin.size() != 1) {
×
NEW
89
        throw InvalidSDFGException("GEMMNode input B must have 1 dimensions");
×
90
    }
NEW
91
    data_flow::Subset b_dims;
×
NEW
92
    for (size_t i = 0; i < b_subset_begin.size(); i++) {
×
NEW
93
        b_dims.push_back(symbolic::sub(b_subset_end[i], b_subset_begin[i]));
×
NEW
94
    }
×
95

NEW
96
    auto& c_memlet = memlets.at("C");
×
NEW
97
    auto& c_subset_begin = c_memlet->begin_subset();
×
NEW
98
    auto& c_subset_end = c_memlet->end_subset();
×
NEW
99
    if (c_subset_begin.size() != 1) {
×
NEW
100
        throw InvalidSDFGException("GEMMNode input C must have 1 dimensions");
×
101
    }
NEW
102
    data_flow::Subset c_dims;
×
NEW
103
    for (size_t i = 0; i < c_subset_begin.size(); i++) {
×
NEW
104
        c_dims.push_back(symbolic::sub(c_subset_end[i], c_subset_begin[i]));
×
NEW
105
    }
×
106

107
    // TODO: Check if dimensions of A, B, C are valid
NEW
108
}
×
109

NEW
110
/// Expands this GEMM node into lower-level SDFG constructs.
///
/// Not implemented yet: the graph is left untouched and `false` is returned,
/// signalling that no expansion was performed.
///
/// @param builder          Builder for the structured SDFG containing this node (currently unused).
/// @param analysis_manager Analysis manager (currently unused).
/// @return Always `false` for now.
bool GEMMNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // TODO: Expand GEMM node.
    // An implementation will likely need the SDFG (builder.subject()) and the Block that
    // owns this node's data-flow graph (via this->get_parent().get_parent()). Those
    // lookups were removed while unused: the previous unchecked static_cast of the
    // dereferenced parent pointer was dead code that could dereference a null parent.
    return false;
}
119

120
std::unique_ptr<data_flow::DataFlowNode> GEMMNode::
NEW
121
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
122
    return std::unique_ptr<data_flow::DataFlowNode>(new GEMMNode(
×
NEW
123
        element_id,
×
NEW
124
        this->debug_info(),
×
NEW
125
        vertex,
×
NEW
126
        parent,
×
NEW
127
        this->precision_,
×
NEW
128
        this->layout_,
×
NEW
129
        this->trans_a_,
×
NEW
130
        this->trans_b_,
×
NEW
131
        this->m_,
×
NEW
132
        this->n_,
×
NEW
133
        this->k_,
×
NEW
134
        this->lda_,
×
NEW
135
        this->ldb_,
×
NEW
136
        this->ldc_,
×
NEW
137
        this->alpha_,
×
NEW
138
        this->beta_
×
139
    ));
NEW
140
}
×
141

NEW
142
/// Serializes a GEMM library node to JSON.
///
/// Emits the node code, the BLAS configuration (precision, layout, transposes),
/// the symbolic sizes / leading dimensions (encoded via JSONSerializer::expression)
/// and the alpha / beta strings.
///
/// @param library_node Node to serialize; must actually be a GEMMNode.
/// @return JSON object describing the node.
nlohmann::json GEMMNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const GEMMNode&>(library_node);
    nlohmann::json result;

    // Named so it does not shadow the `serializer` namespace.
    serializer::JSONSerializer expr_serializer;

    // Node identity and BLAS configuration.
    result["code"] = node.code().value();
    result["precision"] = node.precision();
    result["layout"] = node.layout();
    result["trans_a"] = node.trans_a();
    result["trans_b"] = node.trans_b();

    // Symbolic problem sizes and leading dimensions.
    result["m"] = expr_serializer.expression(node.m());
    result["n"] = expr_serializer.expression(node.n());
    result["k"] = expr_serializer.expression(node.k());
    result["lda"] = expr_serializer.expression(node.lda());
    result["ldb"] = expr_serializer.expression(node.ldb());
    result["ldc"] = expr_serializer.expression(node.ldc());

    // Scalar factors, stored verbatim.
    result["alpha"] = node.alpha();
    result["beta"] = node.beta();

    return result;
}
×
163

NEW
164
/// Reconstructs a GEMM library node from JSON and adds it to `parent`.
///
/// @param j       JSON produced by GEMMNodeSerializer::serialize (plus element_id / debug_info).
/// @param builder Builder used to create the node.
/// @param parent  Block that receives the new node.
/// @return Reference to the newly added node.
/// @throws std::runtime_error if "code" does not denote a GEMM node.
/// @throws nlohmann::json exceptions if a required field is missing or has the wrong type.
///
/// NOTE(review): "element_id" is asserted but not used; the builder presumably
/// assigns a fresh id — confirm this is intended.
data_flow::LibraryNode& GEMMNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields (debug builds only).
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // Use at() rather than operator[]: on a const json, operator[] with a missing key is
    // undefined behavior, and the asserts above are compiled out under NDEBUG.
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_GEMM.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto layout = j.at("layout").get<BLAS_Layout>();
    auto trans_a = j.at("trans_a").get<BLAS_Transpose>();
    auto trans_b = j.at("trans_b").get<BLAS_Transpose>();
    auto m = SymEngine::Expression(j.at("m"));
    auto n = SymEngine::Expression(j.at("n"));
    auto k = SymEngine::Expression(j.at("k"));
    auto lda = SymEngine::Expression(j.at("lda"));
    auto ldb = SymEngine::Expression(j.at("ldb"));
    auto ldc = SymEngine::Expression(j.at("ldc"));
    auto alpha = j.at("alpha").get<std::string>();
    auto beta = j.at("beta").get<std::string>();

    return builder.add_library_node<
        GEMMNode>(parent, debug_info, precision, layout, trans_a, trans_b, m, n, k, lda, ldb, ldc, alpha, beta);
}
×
197

NEW
198
/// Creates a code-generation dispatcher for a GEMM node.
///
/// All arguments are forwarded unchanged to codegen::LibraryNodeDispatcher;
/// this class initializes no state of its own.
GEMMNodeDispatcher::GEMMNodeDispatcher(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const GEMMNode& node
)
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {
    // Nothing GEMM-specific to set up.
}
×
205

NEW
206
/// Emits code for the GEMM node.
///
/// Not implemented yet: nothing is written to `stream`.
///
/// @throws std::runtime_error always.
void GEMMNodeDispatcher::dispatch(codegen::PrettyPrinter& stream) {
    static_cast<void>(stream); // intentionally unused until code generation exists
    throw std::runtime_error("GEMMNode not implemented");
}
×
209

210

211
} // namespace blas
212
} // namespace math
213
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc