• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 17300165647

28 Aug 2025 03:14PM UTC coverage: 60.049% (+0.3%) from 59.781%
17300165647

Pull #210

github

web-flow
Merge f6109d03a into 18d34db1e
Pull Request #210: New debug info

377 of 593 new or added lines in 37 files covered. (63.58%)

15 existing lines in 8 files now uncovered.

9588 of 15967 relevant lines covered (60.05%)

114.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/ml/gemm.cpp
1
#include "sdfg/data_flow/library_nodes/math/ml/gemm.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/analysis/scope_analysis.h"
5
#include "sdfg/builder/structured_sdfg_builder.h"
6

7
namespace sdfg {
8
namespace math {
9
namespace ml {
10

11
GemmNode::GemmNode(
×
12
    size_t element_id,
13
    const DebugInfoRegion &debug_info,
14
    const graph::Vertex vertex,
15
    data_flow::DataFlowGraph &parent,
16
    const std::string &alpha,
17
    const std::string &beta,
18
    bool trans_a,
19
    bool trans_b
20
)
21
    : MathNode(
×
NEW
22
          element_id,
×
NEW
23
          debug_info,
×
NEW
24
          vertex,
×
NEW
25
          parent,
×
26
          LibraryNodeType_Gemm,
NEW
27
          {"Y"},
×
NEW
28
          {"A", "B", "C"},
×
29
          data_flow::ImplementationType_NONE
30
      ),
NEW
31
      alpha_(alpha), beta_(beta), trans_a_(trans_a), trans_b_(trans_b) {}
×
32

33
/**
 * Validation hook for GemmNode.
 *
 * No checks are implemented yet, so every GEMM node is accepted as-is.
 * NOTE(review): expand() bails out when the "A", "B" input or "Y" output edge
 * is missing and indexes A/B/Y with two subscripts; those are natural
 * candidates for checks here.
 */
void GemmNode::validate(const Function & /* function */) const {
    // Intentionally empty (TODO): validation not implemented.
}
×
34

35
bool GemmNode::expand(builder::StructuredSDFGBuilder &builder, analysis::AnalysisManager &analysis_manager) {
×
36
    auto &dataflow = this->get_parent();
×
37
    auto &block = static_cast<structured_control_flow::Block &>(*dataflow.get_parent());
×
38

39
    auto &scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
40
    auto &parent = static_cast<structured_control_flow::Sequence &>(*scope_analysis.parent_scope(&block));
×
41

42
    // Locate edges
43
    const data_flow::Memlet *iedge_A = nullptr;
×
44
    const data_flow::Memlet *iedge_B = nullptr;
×
45
    const data_flow::Memlet *iedge_C = nullptr;
×
46
    const data_flow::Memlet *oedge_Y = nullptr;
×
47
    for (const auto &edge : dataflow.in_edges(*this)) {
×
48
        if (edge.dst_conn() == "A") {
×
49
            iedge_A = &edge;
×
50
        }
×
51
        if (edge.dst_conn() == "B") {
×
52
            iedge_B = &edge;
×
53
        }
×
54
        if (edge.dst_conn() == "C") {
×
55
            iedge_C = &edge;
×
56
        }
×
57
    }
58
    for (const auto &edge : dataflow.out_edges(*this)) {
×
59
        if (edge.src_conn() == "Y") {
×
60
            oedge_Y = &edge;
×
61
        }
×
62
    }
63
    if (!iedge_A || !iedge_B || !oedge_Y) return false;
×
64

65
    bool has_C_in = iedge_C != nullptr;
×
66

67
    std::string A_name = static_cast<const data_flow::AccessNode &>(iedge_A->src()).data();
×
68
    std::string B_name = static_cast<const data_flow::AccessNode &>(iedge_B->src()).data();
×
69
    std::string C_in_name = has_C_in ? static_cast<const data_flow::AccessNode &>(iedge_C->src()).data() : "";
×
70
    std::string C_out_name = static_cast<const data_flow::AccessNode &>(oedge_Y->dst()).data();
×
71

72
    // Create new sequence before
73
    auto &new_sequence = builder.add_sequence_before(parent, block, block.debug_info()).first;
×
74
    structured_control_flow::Sequence *last_scope = &new_sequence;
×
75

76
    // Create maps over output subset dims (parallel dims)
77
    data_flow::Subset domain_begin = {
×
78
        symbolic::integer(0),
×
79
        symbolic::integer(0),
×
80
        symbolic::integer(0),
×
81
    };
82
    data_flow::Subset domain_end = {
×
83
        oedge_Y->end_subset()[0],
×
84
        oedge_Y->end_subset()[1],
×
85
        trans_a_ ? iedge_A->end_subset()[1] : iedge_A->end_subset()[0],
×
86
    };
87

88
    std::vector<symbolic::Expression> out_syms;
×
89
    structured_control_flow::Map *last_map = nullptr;
×
90
    for (size_t d = 0; d < domain_begin.size(); ++d) {
×
91
        std::string indvar_str = builder.find_new_name("_i");
×
92
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
×
93
        auto indvar = symbolic::symbol(indvar_str);
×
94
        auto init = domain_begin[d];
×
95
        auto update = symbolic::add(indvar, symbolic::one());
×
96
        auto cond = symbolic::Lt(indvar, symbolic::add(domain_end[d], symbolic::one()));
×
97
        last_map = &builder.add_map(
×
98
            *last_scope,
×
99
            indvar,
100
            cond,
101
            init,
102
            update,
103
            structured_control_flow::ScheduleType_Sequential,
104
            {},
×
105
            block.debug_info()
×
106
        );
107
        last_scope = &last_map->root();
×
108
        out_syms.push_back(indvar);
×
109
    }
×
110

111
    // Create innermost block
112
    auto &code_block = builder.add_block(*last_scope);
×
NEW
113
    auto &tasklet =
×
NEW
114
        builder
×
NEW
115
            .add_tasklet(code_block, data_flow::TaskletCode::fma, "_out", {"_in1", "_in2", "_in3"}, block.debug_info());
×
116

117
    auto &A_in = builder.add_access(code_block, A_name, block.debug_info());
×
118
    auto &B_in = builder.add_access(code_block, B_name, block.debug_info());
×
119
    auto &C_in = builder.add_access(code_block, has_C_in ? C_in_name : C_out_name, block.debug_info());
×
120
    auto &C_out = builder.add_access(code_block, C_out_name, block.debug_info());
×
121

122
    data_flow::Subset subset_A;
×
123
    if (trans_a_) {
×
124
        subset_A = {out_syms[1], out_syms[0]};
×
125
    } else {
×
126
        subset_A = {out_syms[0], out_syms[1]};
×
127
    }
128
    data_flow::Subset subset_B;
×
129
    if (trans_b_) {
×
130
        subset_B = {out_syms[1], out_syms[0]};
×
131
    } else {
×
132
        subset_B = {out_syms[0], out_syms[1]};
×
133
    }
134
    data_flow::Subset subset_C = {out_syms[0], out_syms[1]};
×
135

NEW
136
    builder
×
NEW
137
        .add_computational_memlet(code_block, A_in, tasklet, "_in1", subset_A, iedge_A->base_type(), block.debug_info());
×
NEW
138
    builder
×
NEW
139
        .add_computational_memlet(code_block, B_in, tasklet, "_in2", subset_B, iedge_B->base_type(), block.debug_info());
×
NEW
140
    builder
×
NEW
141
        .add_computational_memlet(code_block, C_in, tasklet, "_in3", subset_C, oedge_Y->base_type(), block.debug_info());
×
NEW
142
    builder
×
NEW
143
        .add_computational_memlet(code_block, tasklet, "_out", C_out, subset_C, oedge_Y->base_type(), block.debug_info());
×
144

145
    // Cleanup old block
146
    builder.remove_memlet(block, *iedge_A);
×
147
    builder.remove_memlet(block, *iedge_B);
×
148
    if (has_C_in) {
×
149
        builder.remove_memlet(block, *iedge_C);
×
150
    }
×
151
    builder.remove_memlet(block, *oedge_Y);
×
152
    builder.remove_node(block, *this);
×
153
    builder.remove_child(parent, block);
×
154

155
    return true;
×
156
}
×
157

158
std::unique_ptr<data_flow::DataFlowNode> GemmNode::
159
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph &parent) const {
×
NEW
160
    return std::unique_ptr<data_flow::DataFlowNode>(
×
NEW
161
        new GemmNode(element_id, this->debug_info(), vertex, parent, alpha_, beta_, trans_a_, trans_b_)
×
162
    );
UNCOV
163
}
×
164

165
/**
 * Serializes a GemmNode into a JSON object.
 *
 * Emitted keys: "code", "alpha", "beta", "trans_a", "trans_b".
 * NOTE(review): deserialize() additionally requires a "debug_info_region" key,
 * which is not written here; presumably the caller adds it. Confirm this.
 *
 * @param library_node  Must be a GemmNode; it is downcast with static_cast.
 */
nlohmann::json GemmNodeSerializer::serialize(const data_flow::LibraryNode &library_node) {
    const auto &gemm = static_cast<const GemmNode &>(library_node);

    nlohmann::json result = {
        {"code", gemm.code().value()},
        {"alpha", gemm.alpha()},
        {"beta", gemm.beta()},
        {"trans_a", gemm.trans_a()},
        {"trans_b", gemm.trans_b()},
    };
    return result;
}
×
177

178
/**
 * Reconstructs a GemmNode from JSON and adds it to `parent` through `builder`.
 *
 * Required keys: "code" (must equal LibraryNodeType_Gemm), "debug_info_region",
 * "alpha", "beta" (strings) and "trans_a", "trans_b" (booleans).
 *
 * @return the library node created by builder.add_library_node<GemmNode>.
 * @throws std::runtime_error if "code" does not name a GEMM node.
 * @throws nlohmann::json::exception (out_of_range / type_error) if a required
 *         key is missing or holds a value of the wrong JSON type.
 */
data_flow::LibraryNode &GemmNodeSerializer::deserialize(
    const nlohmann::json &j, builder::StructuredSDFGBuilder &builder, structured_control_flow::Block &parent
) {
    // Use at() rather than operator[]: on a const json, operator[] with a
    // missing key is undefined behavior, whereas at() throws out_of_range.
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_Gemm.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfoRegion debug_info = serializer.json_to_debug_info_region(j.at("debug_info_region"), builder.debug_info());

    auto alpha = j.at("alpha").get<std::string>();
    auto beta = j.at("beta").get<std::string>();
    auto trans_a = j.at("trans_a").get<bool>();
    auto trans_b = j.at("trans_b").get<bool>();

    return builder.add_library_node<GemmNode>(parent, debug_info, alpha, beta, trans_a, trans_b);
}
×
196

197
} // namespace ml
198
} // namespace math
199
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc