• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20561540736

29 Dec 2025 12:13AM UTC coverage: 40.366% (+1.4%) from 38.976%
20561540736

push

github

web-flow
Merge pull request #409 from daisytuner/lib-nodes-refactor

restructures library nodes

14298 of 45900 branches covered (31.15%)

Branch coverage included in aggregate %.

259 of 388 new or added lines in 19 files covered. (66.75%)

28 existing lines in 2 files now uncovered.

12247 of 19861 relevant lines covered (61.66%)

89.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

17.08
/src/data_flow/library_nodes/math/blas/dot_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/dot_node.h"
2
#include <stdexcept>
3
#include <string>
4

5
#include "sdfg/analysis/analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7

8
#include "sdfg/analysis/scope_analysis.h"
9
#include "sdfg/symbolic/symbolic.h"
10

11
namespace sdfg {
12
namespace math {
13
namespace blas {
14

15
DotNode::DotNode(
1✔
16
    size_t element_id,
17
    const DebugInfo& debug_info,
18
    const graph::Vertex vertex,
19
    data_flow::DataFlowGraph& parent,
20
    const data_flow::ImplementationType& implementation_type,
21
    const BLAS_Precision& precision,
22
    symbolic::Expression n,
23
    symbolic::Expression incx,
24
    symbolic::Expression incy
25
)
26
    : BLASNode(
1!
27
          element_id, debug_info, vertex, parent, LibraryNodeType_DOT, {"_out"}, {"x", "y"}, implementation_type, precision
1!
28
      ),
29
      n_(n), incx_(incx), incy_(incy) {}
1!
30

31
symbolic::Expression DotNode::n() const {
    // Number of elements processed (the loop bound / first BLAS argument).
    return n_;
}
32

33
symbolic::Expression DotNode::incx() const {
    // Stride between consecutive elements of the `x` operand.
    return incx_;
}
34

35
symbolic::Expression DotNode::incy() const {
    // Stride between consecutive elements of the `y` operand.
    return incy_;
}
36

NEW
37
/// Returns the union of `symbolic::atoms` over the node's symbolic parameters:
/// the length `n` and the strides `incx` and `incy`.
symbolic::SymbolSet DotNode::symbols() const {
    symbolic::SymbolSet result;
    for (const auto& expr : {this->n_, this->incx_, this->incy_}) {
        for (auto& atom : symbolic::atoms(expr)) {
            result.insert(atom);
        }
    }
    return result;
}
52

53
/// Substitutes `old_expression` by `new_expression` in the length and both strides.
void DotNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    for (symbolic::Expression* field : {&this->n_, &this->incx_, &this->incy_}) {
        *field = symbolic::subs(*field, old_expression, new_expression);
    }
}
58

59
// Validation hook for the DOT node; currently performs no checks and never throws.
// NOTE(review): the cuBLAS dispatchers in this file reject precisions other than `s`/`d`
// only at code-generation time; consider validating precision (and connector wiring)
// here instead — TODO confirm the intended contract of validate().
void DotNode::validate(const Function& function) const {}
×
60

61
/// Lowers this DOT node into an explicit sequential loop.
///
/// A new sequence is inserted before the node's block. It contains
///   for (_i = 0; _i < n; _i = _i + 1)
///       fp_fma tasklet: _out <- (x[_i * incx], y[_i * incy], res)
/// after which the original memlets, access nodes, this node and the original block
/// are removed.
///
/// @return true if the node was expanded. Returns false, leaving the graph untouched,
///         when the node is not in the expected shape:
///         - a missing `x` / `y` / `_out` memlet;
///         - an endpoint that is not an access node;
///         - an access node with other producers or consumers.
///
/// NOTE(review): the `_in3` operand reads the current value of the result (assuming
/// fp_fma computes _in1 * _in2 + _in3). The loop therefore accumulates into whatever the
/// result holds on entry, whereas the BLAS dispatcher assigns (`_out = cblas_?dot(...)`).
/// TODO confirm the output is zero-initialised, or insert an initialisation.
/// NOTE(review): BLAS semantics for negative increments (start at (1 - n) * inc) are not
/// reproduced; the loop indexes `_i * inc` directly.
bool DotNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();

    auto& dataflow = this->get_parent();
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
    int index = parent.index(block);
    auto& transition = parent.at(index).second;

    // Locate the memlets attached to the node's connectors.
    const data_flow::Memlet* iedge_x = nullptr;
    const data_flow::Memlet* iedge_y = nullptr;
    for (const auto& iedge : dataflow.in_edges(*this)) {
        if (iedge.dst_conn() == "x") {
            iedge_x = &iedge;
        } else if (iedge.dst_conn() == "y") {
            iedge_y = &iedge;
        }
    }

    const data_flow::Memlet* oedge_res = nullptr;
    for (const auto& oedge : dataflow.out_edges(*this)) {
        if (oedge.src_conn() == "_out") {
            oedge_res = &oedge;
            break;
        }
    }

    // Check if legal: all three connectors must be wired ...
    if (iedge_x == nullptr || iedge_y == nullptr || oedge_res == nullptr) {
        return false;
    }

    // ... to access nodes (the loop body re-creates them by container name) ...
    auto* input_node_x = dynamic_cast<const data_flow::AccessNode*>(&iedge_x->src());
    auto* input_node_y = dynamic_cast<const data_flow::AccessNode*>(&iedge_y->src());
    auto* output_node_res = dynamic_cast<const data_flow::AccessNode*>(&oedge_res->dst());
    if (input_node_x == nullptr || input_node_y == nullptr || output_node_res == nullptr) {
        return false;
    }

    // ... which have no other producers/consumers, since they are removed below.
    if (dataflow.in_degree(*input_node_x) != 0 || dataflow.in_degree(*input_node_y) != 0 ||
        dataflow.out_degree(*output_node_res) != 0) {
        return false;
    }

    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());

    std::string loop_var = builder.find_new_name("_i");
    builder.add_container(loop_var, types::Scalar(types::PrimitiveType::UInt64));

    auto loop_indvar = symbolic::symbol(loop_var);
    auto loop_init = symbolic::integer(0);
    auto loop_condition = symbolic::Lt(loop_indvar, this->n_);
    auto loop_update = symbolic::add(loop_indvar, symbolic::integer(1));

    auto& loop =
        builder.add_for(new_sequence, loop_indvar, loop_condition, loop_init, loop_update, {}, block.debug_info());
    auto& body = loop.root();

    auto& new_block = builder.add_block(body);

    auto& res_in = builder.add_access(new_block, output_node_res->data());
    auto& res_out = builder.add_access(new_block, output_node_res->data());
    auto& x = builder.add_access(new_block, input_node_x->data());
    auto& y = builder.add_access(new_block, input_node_y->data());

    auto& tasklet = builder.add_tasklet(new_block, data_flow::TaskletCode::fp_fma, "_out", {"_in1", "_in2", "_in3"});

    builder.add_computational_memlet(
        new_block,
        x,
        tasklet,
        "_in1",
        {symbolic::mul(loop_indvar, this->incx_)},
        iedge_x->base_type(),
        iedge_x->debug_info()
    );
    builder.add_computational_memlet(
        new_block,
        y,
        tasklet,
        "_in2",
        {symbolic::mul(loop_indvar, this->incy_)},
        iedge_y->base_type(),
        iedge_y->debug_info()
    );
    builder
        .add_computational_memlet(new_block, res_in, tasklet, "_in3", {}, oedge_res->base_type(), oedge_res->debug_info());
    builder.add_computational_memlet(
        new_block, tasklet, "_out", res_out, {}, oedge_res->base_type(), oedge_res->debug_info()
    );

    // Clean up the original graph. After remove_node(*this) nothing below touches `this`.
    builder.remove_memlet(block, *iedge_x);
    builder.remove_memlet(block, *iedge_y);
    builder.remove_memlet(block, *oedge_res);
    builder.remove_node(block, *input_node_x);
    builder.remove_node(block, *input_node_y);
    builder.remove_node(block, *output_node_res);
    builder.remove_node(block, *this);
    // The original block now sits at index + 1 because the new sequence was inserted before it.
    // NOTE(review): any other nodes in that block are dropped too; the legality check does
    // not verify that the block contains only this node.
    builder.remove_child(parent, index + 1);

    return true;
}
156

157
/// Floating-point operation count of the dot product: n multiplications plus (n - 1)
/// additions to reduce the products, i.e. 2n - 1.
/// NOTE(review): for n == 0 this evaluates to -1 — TODO confirm callers never query that case.
symbolic::Expression DotNode::flop() const {
    const symbolic::Expression multiplications = this->n_;
    const symbolic::Expression additions = symbolic::sub(multiplications, symbolic::one());
    return symbolic::add(multiplications, additions);
}
162

163
std::unique_ptr<data_flow::DataFlowNode> DotNode::
164
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
165
    auto node_clone = std::unique_ptr<DotNode>(new DotNode(
×
166
        element_id,
×
167
        this->debug_info(),
×
168
        vertex,
×
169
        parent,
×
170
        this->implementation_type_,
×
171
        this->precision_,
×
172
        this->n_,
×
173
        this->incx_,
×
174
        this->incy_
×
175
    ));
176
    return std::move(node_clone);
×
177
}
×
178

179
/// Serializes the DOT-specific fields of `library_node` (which must be a DotNode):
/// `code`, `precision`, and the symbolic `n`, `incx` and `incy`, each converted via
/// JSONSerializer::expression.
/// NOTE(review): deserialize() also reads `element_id`, `debug_info` and
/// `implementation_type`, which are not written here — presumably added by the generic
/// serializer; verify against the caller.
nlohmann::json DotNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& dot_node = static_cast<const DotNode&>(library_node);
    serializer::JSONSerializer json_serializer;

    nlohmann::json result;
    result["code"] = dot_node.code().value();
    result["precision"] = dot_node.precision();
    result["n"] = json_serializer.expression(dot_node.n());
    result["incx"] = json_serializer.expression(dot_node.incx());
    result["incy"] = json_serializer.expression(dot_node.incy());
    return result;
}
192

193
/// Reconstructs a DotNode from JSON and adds it to `parent` through `builder`.
///
/// Reads `code`, `debug_info`, `precision`, `n`, `incx`, `incy` and `implementation_type`.
/// `element_id` is asserted to be present but is not used.
/// @throws std::runtime_error if `code` does not name a DOT node.
/// @throws nlohmann::json exceptions (from `at`/`get`) if a field is missing or mistyped.
data_flow::LibraryNode& DotNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields (debug builds only)
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // at() rather than operator[]: const operator[] with a missing key is undefined
    // behaviour, and the assertions above disappear under NDEBUG.
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_DOT.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer json_serializer;
    DebugInfo debug_info = json_serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto n = symbolic::parse(j.at("n"));
    auto incx = symbolic::parse(j.at("incx"));
    auto incy = symbolic::parse(j.at("incy"));

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<DotNode>(parent, debug_info, implementation_type, precision, n, incx, incy);
}
219

220
/// CBLAS code generator for a DotNode. It has no state of its own; all arguments are
/// forwarded to the base LibraryNodeDispatcher.
DotNodeDispatcher_BLAS::DotNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {}
227

228
void DotNodeDispatcher_BLAS::dispatch_code(
×
229
    codegen::PrettyPrinter& stream,
230
    codegen::PrettyPrinter& globals_stream,
231
    codegen::CodeSnippetFactory& library_snippet_factory
232
) {
233
    stream << "{" << std::endl;
×
234
    stream.setIndent(stream.indent() + 4);
×
235

236
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
237

238
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
239
    BLAS_Precision precision = dot_node.precision();
×
240
    switch (precision) {
×
241
        case BLAS_Precision::h:
242
            base_type = types::Scalar(types::PrimitiveType::Half);
×
243
            break;
×
244
        case BLAS_Precision::s:
245
            base_type = types::Scalar(types::PrimitiveType::Float);
×
246
            break;
×
247
        case BLAS_Precision::d:
248
            base_type = types::Scalar(types::PrimitiveType::Double);
×
249
            break;
×
250
        default:
251
            throw std::runtime_error("Invalid BLAS_Precision value");
×
252
    }
253

254
    stream << dot_node.outputs().at(0) << " = ";
×
255
    stream << "cblas_" << BLAS_Precision_to_string(precision) << "dot(";
×
256
    stream.setIndent(stream.indent() + 4);
×
257
    stream << this->language_extension_.expression(dot_node.n());
×
258
    stream << ", ";
×
259
    stream << dot_node.inputs().at(0);
×
260
    stream << ", ";
×
261
    stream << this->language_extension_.expression(dot_node.incx());
×
262
    stream << ", ";
×
263
    stream << dot_node.inputs().at(1);
×
264
    stream << ", ";
×
265
    stream << this->language_extension_.expression(dot_node.incy());
×
266
    stream.setIndent(stream.indent() - 4);
×
267
    stream << ");" << std::endl;
×
268

269
    stream.setIndent(stream.indent() - 4);
×
270
    stream << "}" << std::endl;
×
271
}
×
272

273
/// cuBLAS code generator for a DotNode whose operands live in host memory. It has no
/// state of its own; all arguments are forwarded to the base LibraryNodeDispatcher.
DotNodeDispatcher_CUBLASWithTransfers::DotNodeDispatcher_CUBLASWithTransfers(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {}
280

281
void DotNodeDispatcher_CUBLASWithTransfers::dispatch_code(
×
282
    codegen::PrettyPrinter& stream,
283
    codegen::PrettyPrinter& globals_stream,
284
    codegen::CodeSnippetFactory& library_snippet_factory
285
) {
286
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
287

288
    globals_stream << "#include <cuda.h>" << std::endl;
×
289
    globals_stream << "#include <cublas_v2.h>" << std::endl;
×
290

291
    std::string type, type2;
×
292
    switch (dot_node.precision()) {
×
293
        case s:
294
            type = "float";
×
295
            type2 = "S";
×
296
            break;
×
297
        case d:
298
            type = "double";
×
299
            type2 = "D";
×
300
            break;
×
301
        default:
302
            throw std::runtime_error("Invalid precision for CUBLAS DOT node");
×
303
    }
304

305
    const std::string x_size =
306
        this->language_extension_.expression(
×
307
            symbolic::add(symbolic::mul(symbolic::sub(dot_node.n(), symbolic::one()), dot_node.incx()), symbolic::one())
×
308
        ) +
×
309
        " * sizeof(" + type + ")";
×
310
    const std::string y_size =
311
        this->language_extension_.expression(
×
312
            symbolic::add(symbolic::mul(symbolic::sub(dot_node.n(), symbolic::one()), dot_node.incy()), symbolic::one())
×
313
        ) +
×
314
        " * sizeof(" + type + ")";
×
315

316
    stream << type << " *dx, *dy;" << std::endl;
×
317
    stream << "cudaMalloc(&dx, " << x_size << ");" << std::endl;
×
318
    stream << "cudaMalloc(&dy, " << y_size << ");" << std::endl;
×
319

320
    stream << "cudaMemcpy(dx, x, " << x_size << ", cudaMemcpyHostToDevice);" << std::endl;
×
321
    stream << "cudaMemcpy(dy, y, " << y_size << ", cudaMemcpyHostToDevice);" << std::endl;
×
322

323
    stream << "cublasStatus_t err;" << std::endl;
×
324
    stream << "cublasHandle_t handle;" << std::endl;
×
325
    stream << "err = cublasCreate(&handle);" << std::endl;
×
326
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
327
    stream.setIndent(stream.indent() + 4);
×
328
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
329
    stream.setIndent(stream.indent() - 4);
×
330
    stream << "}" << std::endl;
×
331
    stream << "err = cublas" << type2 << "dot(handle, " << this->language_extension_.expression(dot_node.n())
×
332
           << ", dx, " << this->language_extension_.expression(dot_node.incx()) << ", dy, "
×
333
           << this->language_extension_.expression(dot_node.incy()) << ", &_out);" << std::endl;
×
334
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
335
    stream.setIndent(stream.indent() + 4);
×
336
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
337
    stream.setIndent(stream.indent() - 4);
×
338
    stream << "}" << std::endl;
×
339
    stream << "err = cublasDestroy(handle);" << std::endl;
×
340
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
341
    stream.setIndent(stream.indent() + 4);
×
342
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
343
    stream.setIndent(stream.indent() - 4);
×
344
    stream << "}" << std::endl;
×
345

346
    stream << "cudaFree(dx);" << std::endl;
×
347
    stream << "cudaFree(dy);" << std::endl;
×
348
}
×
349

350
/// cuBLAS code generator for a DotNode whose operands are already device-accessible.
/// It has no state of its own; all arguments are forwarded to the base
/// LibraryNodeDispatcher.
DotNodeDispatcher_CUBLASWithoutTransfers::DotNodeDispatcher_CUBLASWithoutTransfers(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {}
357

358
void DotNodeDispatcher_CUBLASWithoutTransfers::dispatch_code(
×
359
    codegen::PrettyPrinter& stream,
360
    codegen::PrettyPrinter& globals_stream,
361
    codegen::CodeSnippetFactory& library_snippet_factory
362
) {
363
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
364

365
    globals_stream << "#include <cuda.h>" << std::endl;
×
366
    globals_stream << "#include <cublas_v2.h>" << std::endl;
×
367

368
    stream << "cublasStatus_t err;" << std::endl;
×
369
    stream << "cublasHandle_t handle;" << std::endl;
×
370
    stream << "err = cublasCreate(&handle);" << std::endl;
×
371
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
372
    stream.setIndent(stream.indent() + 4);
×
373
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
374
    stream.setIndent(stream.indent() - 4);
×
375
    stream << "}" << std::endl;
×
376
    stream << "err = cublas";
×
377
    switch (dot_node.precision()) {
×
378
        case s:
379
            stream << "S";
×
380
            break;
×
381
        case d:
382
            stream << "D";
×
383
            break;
×
384
        default:
385
            throw std::runtime_error("Invalid precision for CUBLAS DOT node");
×
386
    }
387
    stream << "dot(handle, " << this->language_extension_.expression(dot_node.n()) << ", x, "
×
388
           << this->language_extension_.expression(dot_node.incx()) << ", y, "
×
389
           << this->language_extension_.expression(dot_node.incy()) << ", &_out);" << std::endl;
×
390
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
391
    stream.setIndent(stream.indent() + 4);
×
392
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
393
    stream.setIndent(stream.indent() - 4);
×
394
    stream << "}" << std::endl;
×
395
    stream << "err = cublasDestroy(handle);" << std::endl;
×
396
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
397
    stream.setIndent(stream.indent() + 4);
×
398
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
399
    stream.setIndent(stream.indent() - 4);
×
400
    stream << "}" << std::endl;
×
401
}
×
402

403
} // namespace blas
404
} // namespace math
405
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc