• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 17651658650

11 Sep 2025 04:58PM UTC coverage: 61.012% (+1.3%) from 59.755%
17651658650

Pull #219

github

web-flow
Merge 742a12367 into f744ac9f5
Pull Request #219: stdlib Library Nodes and ConstantNodes

499 of 1681 new or added lines in 81 files covered. (29.68%)

95 existing lines in 36 files now uncovered.

9718 of 15928 relevant lines covered (61.01%)

108.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.02
/src/data_flow/library_nodes/math/blas/dot.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/dot.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace blas {
11

12
/// Constructs a BLAS `dot` library node with one output connector (`_out`)
/// and two input connectors (`x`, `y`).
///
/// @param precision  BLAS precision used when lowering the node.
/// @param n          symbolic number of elements (used as the loop bound in expand()).
/// @param incx       stride between consecutive elements of `x`.
/// @param incy       stride between consecutive elements of `y`.
DotNode::DotNode(
    size_t element_id,
    const DebugInfoRegion& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const data_flow::ImplementationType& implementation_type,
    const BLAS_Precision& precision,
    symbolic::Expression n,
    symbolic::Expression incx,
    symbolic::Expression incy
)
    : MathNode(element_id, debug_info, vertex, parent, LibraryNodeType_DOT, {"_out"}, {"x", "y"}, implementation_type),
      precision_(precision),
      // The expressions are sink parameters (taken by value), so move them into
      // the members instead of paying for a second copy.
      n_(std::move(n)),
      incx_(std::move(incx)),
      incy_(std::move(incy)) {}
25

26
/// Returns the BLAS precision this node was constructed with.
BLAS_Precision DotNode::precision() const { return this->precision_; }
27

28
/// Returns the symbolic element count `n`.
symbolic::Expression DotNode::n() const { return this->n_; }
29

30
/// Returns the symbolic stride applied to the `x` input.
symbolic::Expression DotNode::incx() const { return this->incx_; }
31

32
/// Returns the symbolic stride applied to the `y` input.
symbolic::Expression DotNode::incy() const { return this->incy_; }
33

NEW
34
/// Validation hook for the dot node.
///
/// No DotNode-specific invariants are checked at the moment; the function
/// accepts every configuration.
void DotNode::validate(const Function& /*function*/) const {
    // Intentionally empty.
}
35

36
bool DotNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
1✔
37
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
1✔
38

39
    auto& dataflow = this->get_parent();
1✔
40
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
1✔
41
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
1✔
42
    int index = parent.index(block);
1✔
43
    auto& transition = parent.at(index).second;
1✔
44

45
    const data_flow::Memlet* iedge_x = nullptr;
1✔
46
    const data_flow::Memlet* iedge_y = nullptr;
1✔
47
    for (const auto& iedge : dataflow.in_edges(*this)) {
3✔
48
        if (iedge.dst_conn() == "x") {
2✔
49
            iedge_x = &iedge;
1✔
50
        } else if (iedge.dst_conn() == "y") {
2✔
51
            iedge_y = &iedge;
1✔
52
        }
1✔
53
    }
54

55
    const data_flow::Memlet* oedge_res = nullptr;
1✔
56
    for (const auto& oedge : dataflow.out_edges(*this)) {
1✔
57
        if (oedge.src_conn() == "_out") {
1✔
58
            oedge_res = &oedge;
1✔
59
            break;
1✔
60
        }
61
    }
62

63
    // Check if legal
64
    auto& input_node_x = static_cast<const data_flow::AccessNode&>(iedge_x->src());
1✔
65
    auto& input_node_y = static_cast<const data_flow::AccessNode&>(iedge_y->src());
1✔
66
    auto& output_node_res = static_cast<const data_flow::AccessNode&>(oedge_res->dst());
1✔
67
    if (dataflow.in_degree(input_node_x) != 0 || dataflow.in_degree(input_node_y) != 0 ||
1✔
68
        dataflow.out_degree(output_node_res) != 0) {
1✔
69
        return false;
×
70
    }
71

72
    auto& new_sequence = builder.add_sequence_before(
1✔
73
        parent, block, transition.assignments(), builder.debug_info().get_region(block.debug_info().indices())
1✔
74
    );
75

76
    std::string loop_var = builder.find_new_name("_i");
1✔
77
    builder.add_container(loop_var, types::Scalar(types::PrimitiveType::UInt64));
1✔
78

79
    auto loop_indvar = symbolic::symbol(loop_var);
1✔
80
    auto loop_init = symbolic::integer(0);
1✔
81
    auto loop_condition = symbolic::Lt(loop_indvar, this->n_);
1✔
82
    auto loop_update = symbolic::add(loop_indvar, symbolic::integer(1));
1✔
83

84
    auto& loop = builder.add_for(
2✔
85
        new_sequence,
1✔
86
        loop_indvar,
87
        loop_condition,
88
        loop_init,
1✔
89
        loop_update,
90
        {},
1✔
91
        builder.subject().debug_info().get_region(block.debug_info().indices())
1✔
92
    );
93
    auto& body = loop.root();
1✔
94

95
    auto& new_block = builder.add_block(body);
1✔
96

97
    auto& res_in = builder.add_access(new_block, output_node_res.data());
1✔
98
    auto& res_out = builder.add_access(new_block, output_node_res.data());
1✔
99
    auto& x = builder.add_access(new_block, input_node_x.data());
1✔
100
    auto& y = builder.add_access(new_block, input_node_y.data());
1✔
101

102
    auto& tasklet = builder.add_tasklet(new_block, data_flow::TaskletCode::fma, "_out", {"_in1", "_in2", "_in3"});
1✔
103

104
    builder.add_computational_memlet(
2✔
105
        new_block,
1✔
106
        x,
1✔
107
        tasklet,
1✔
108
        "_in1",
1✔
109
        {symbolic::mul(loop_indvar, this->incx_)},
1✔
110
        iedge_x->base_type(),
1✔
111
        builder.subject().debug_info().get_region(iedge_x->debug_info().indices())
1✔
112
    );
113
    builder.add_computational_memlet(
2✔
114
        new_block,
1✔
115
        y,
1✔
116
        tasklet,
1✔
117
        "_in2",
1✔
118
        {symbolic::mul(loop_indvar, this->incy_)},
1✔
119
        iedge_y->base_type(),
1✔
120
        builder.subject().debug_info().get_region(iedge_y->debug_info().indices())
1✔
121
    );
122
    builder.add_computational_memlet(
2✔
123
        new_block,
1✔
124
        res_in,
1✔
125
        tasklet,
1✔
126
        "_in3",
1✔
127
        {},
1✔
128
        oedge_res->base_type(),
1✔
129
        builder.subject().debug_info().get_region(oedge_res->debug_info().indices())
1✔
130
    );
131
    builder.add_computational_memlet(
2✔
132
        new_block,
1✔
133
        tasklet,
1✔
134
        "_out",
1✔
135
        res_out,
1✔
136
        {},
1✔
137
        oedge_res->base_type(),
1✔
138
        builder.subject().debug_info().get_region(oedge_res->debug_info().indices())
1✔
139
    );
140

141
    // Clean up
142
    builder.remove_memlet(block, *iedge_x);
1✔
143
    builder.remove_memlet(block, *iedge_y);
1✔
144
    builder.remove_memlet(block, *oedge_res);
1✔
145
    builder.remove_node(block, input_node_x);
1✔
146
    builder.remove_node(block, input_node_y);
1✔
147
    builder.remove_node(block, output_node_res);
1✔
148
    builder.remove_node(block, *this);
1✔
149
    builder.remove_child(parent, index + 1);
1✔
150

151
    return true;
1✔
152
}
1✔
153

154
std::unique_ptr<data_flow::DataFlowNode> DotNode::
155
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
156
    auto node_clone = std::unique_ptr<DotNode>(new DotNode(
×
157
        element_id,
×
158
        this->debug_info(),
×
159
        vertex,
×
160
        parent,
×
161
        this->implementation_type_,
×
162
        this->precision_,
×
163
        this->n_,
×
164
        this->incx_,
×
165
        this->incy_
×
166
    ));
167
    return std::move(node_clone);
×
168
}
×
169

170
/// Serializes the DotNode-specific fields to JSON.
///
/// Writes "code", "precision", "n", "incx" and "incy"; the symbolic fields
/// are converted with JSONSerializer::expression.
nlohmann::json DotNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& dot_node = static_cast<const DotNode&>(library_node);
    serializer::JSONSerializer serializer;

    nlohmann::json j;
    j["code"] = dot_node.code().value();
    j["precision"] = dot_node.precision();
    j["n"] = serializer.expression(dot_node.n());
    j["incx"] = serializer.expression(dot_node.incx());
    j["incy"] = serializer.expression(dot_node.incy());

    // NOTE(review): DotNodeSerializer::deserialize also requires "element_id",
    // "debug_info" and "implementation_type", which are not written here;
    // presumably the calling serializer adds them — confirm.
    return j;
}
183

184
/// Reconstructs a DotNode from JSON and adds it to `parent` through `builder`.
///
/// Reads "code", "debug_info", "precision", "n", "incx", "incy" and
/// "implementation_type". "element_id" must be present but is not used here.
///
/// @throws std::runtime_error if "code" is not the DOT library node code.
/// @throws nlohmann::json::out_of_range if a required key is missing.
/// @throws nlohmann::json::type_error if a field has an unexpected JSON type.
data_flow::LibraryNode& DotNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // at() instead of operator[]: operator[] on a const json with a missing key is
    // undefined behaviour once the asserts above are compiled out (NDEBUG).
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_DOT.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfoRegion debug_info = serializer.json_to_debug_info_region(j.at("debug_info"), builder.debug_info());

    auto precision = j.at("precision").get<BLAS_Precision>();
    // The symbolic fields are stored as strings; convert explicitly rather than
    // relying on json's implicit conversion operator.
    auto n = SymEngine::Expression(j.at("n").get<std::string>());
    auto incx = SymEngine::Expression(j.at("incx").get<std::string>());
    auto incy = SymEngine::Expression(j.at("incy").get<std::string>());

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<DotNode>(parent, debug_info, implementation_type, precision, n, incx, incy);
}
210

211
/// Creates the CBLAS code generator for a DotNode.
///
/// All arguments are forwarded unchanged to codegen::LibraryNodeDispatcher;
/// this dispatcher adds no state of its own.
DotNodeDispatcher_BLAS::DotNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {
    // Nothing else to initialise.
}
218

NEW
219
void DotNodeDispatcher_BLAS::dispatch_code(
×
220
    codegen::PrettyPrinter& stream,
221
    codegen::PrettyPrinter& globals_stream,
222
    codegen::CodeSnippetFactory& library_snippet_factory
223
) {
224
    stream << "{" << std::endl;
×
225
    stream.setIndent(stream.indent() + 4);
×
226

227
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
228

229
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
230
    switch (dot_node.precision()) {
×
231
        case BLAS_Precision::h:
232
            base_type = types::Scalar(types::PrimitiveType::Half);
×
233
            break;
×
234
        case BLAS_Precision::s:
235
            base_type = types::Scalar(types::PrimitiveType::Float);
×
236
            break;
×
237
        case BLAS_Precision::d:
238
            base_type = types::Scalar(types::PrimitiveType::Double);
×
239
            break;
×
240
        default:
241
            throw std::runtime_error("Invalid BLAS_Precision value");
×
242
    }
243

NEW
244
    stream << "res = ";
×
245
    stream << "cblas_" << BLAS_Precision_to_string(dot_node.precision()) << "dot(";
×
246
    stream.setIndent(stream.indent() + 4);
×
247
    stream << this->language_extension_.expression(dot_node.n());
×
248
    stream << ", ";
×
249
    stream << "x";
×
250
    stream << ", ";
×
251
    stream << this->language_extension_.expression(dot_node.incx());
×
252
    stream << ", ";
×
253
    stream << "y";
×
254
    stream << ", ";
×
255
    stream << this->language_extension_.expression(dot_node.incy());
×
256
    stream.setIndent(stream.indent() - 4);
×
257
    stream << ");" << std::endl;
×
258

259
    stream.setIndent(stream.indent() - 4);
×
260
    stream << "}" << std::endl;
×
261
}
×
262

263
/// Creates the cuBLAS code generator for a DotNode.
///
/// All arguments are forwarded unchanged to codegen::LibraryNodeDispatcher;
/// this dispatcher adds no state of its own.
DotNodeDispatcher_CUBLAS::DotNodeDispatcher_CUBLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {
    // Nothing else to initialise.
}
270

NEW
271
void DotNodeDispatcher_CUBLAS::dispatch_code(
×
272
    codegen::PrettyPrinter& stream,
273
    codegen::PrettyPrinter& globals_stream,
274
    codegen::CodeSnippetFactory& library_snippet_factory
275
) {
276
    throw std::runtime_error("DotNodeDispatcher_CUBLAS not implemented");
×
277
}
×
278

279
} // namespace blas
280
} // namespace math
281
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc