• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 19696454706

26 Nov 2025 07:54AM UTC coverage: 61.884% (-0.3%) from 62.187%
19696454706

push

github

web-flow
CUBLAS DOT with/without data transfers (#362)

* Split `ImplementationType_CUBLAS` into `ImplementationType_CUBLASWithTransfers` and `ImplementationType_CUBLASWithoutTransfers`
* Implemented dispatchers for CUBLAS DOT
* Added `<cstdint>` to C++ includes

9 of 109 new or added lines in 4 files covered. (8.26%)

2 existing lines in 2 files now uncovered.

11248 of 18176 relevant lines covered (61.88%)

110.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

30.99
/src/data_flow/library_nodes/math/blas/dot.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/dot.h"
2
#include <stdexcept>
3
#include <string>
4

5
#include "sdfg/analysis/analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7

8
#include "sdfg/analysis/scope_analysis.h"
9
#include "sdfg/symbolic/symbolic.h"
10

11
namespace sdfg {
12
namespace math {
13
namespace blas {
14

15
DotNode::DotNode(
1✔
16
    size_t element_id,
17
    const DebugInfo& debug_info,
18
    const graph::Vertex vertex,
19
    data_flow::DataFlowGraph& parent,
20
    const data_flow::ImplementationType& implementation_type,
21
    const BLAS_Precision& precision,
22
    symbolic::Expression n,
23
    symbolic::Expression incx,
24
    symbolic::Expression incy
25
)
26
    : BLASNode(
1✔
27
          element_id, debug_info, vertex, parent, LibraryNodeType_DOT, {"_out"}, {"x", "y"}, implementation_type, precision
1✔
28
      ),
29
      n_(n), incx_(incx), incy_(incy) {}
1✔
30

31
/// Returns the stored expression `n_` (used as the upper loop bound when the node is expanded).
symbolic::Expression DotNode::n() const {
    return n_;
}
×
32

33
/// Returns the stored expression `incx_` (the index multiplier applied to the "x" input).
symbolic::Expression DotNode::incx() const {
    return incx_;
}
×
34

35
/// Returns the stored expression `incy_` (the index multiplier applied to the "y" input).
symbolic::Expression DotNode::incy() const {
    return incy_;
}
×
36

37
/// Validation hook; this node currently performs no checks.
void DotNode::validate(const Function& /*function*/) const {
    // Intentionally empty.
}
×
38

39
bool DotNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
1✔
40
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
1✔
41

42
    auto& dataflow = this->get_parent();
1✔
43
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
1✔
44
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
1✔
45
    int index = parent.index(block);
1✔
46
    auto& transition = parent.at(index).second;
1✔
47

48
    const data_flow::Memlet* iedge_x = nullptr;
1✔
49
    const data_flow::Memlet* iedge_y = nullptr;
1✔
50
    for (const auto& iedge : dataflow.in_edges(*this)) {
3✔
51
        if (iedge.dst_conn() == "x") {
2✔
52
            iedge_x = &iedge;
1✔
53
        } else if (iedge.dst_conn() == "y") {
2✔
54
            iedge_y = &iedge;
1✔
55
        }
1✔
56
    }
57

58
    const data_flow::Memlet* oedge_res = nullptr;
1✔
59
    for (const auto& oedge : dataflow.out_edges(*this)) {
1✔
60
        if (oedge.src_conn() == "_out") {
1✔
61
            oedge_res = &oedge;
1✔
62
            break;
1✔
63
        }
64
    }
65

66
    // Check if legal
67
    auto& input_node_x = static_cast<const data_flow::AccessNode&>(iedge_x->src());
1✔
68
    auto& input_node_y = static_cast<const data_flow::AccessNode&>(iedge_y->src());
1✔
69
    auto& output_node_res = static_cast<const data_flow::AccessNode&>(oedge_res->dst());
1✔
70
    if (dataflow.in_degree(input_node_x) != 0 || dataflow.in_degree(input_node_y) != 0 ||
1✔
71
        dataflow.out_degree(output_node_res) != 0) {
1✔
72
        return false;
×
73
    }
74

75
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
1✔
76

77
    std::string loop_var = builder.find_new_name("_i");
1✔
78
    builder.add_container(loop_var, types::Scalar(types::PrimitiveType::UInt64));
1✔
79

80
    auto loop_indvar = symbolic::symbol(loop_var);
1✔
81
    auto loop_init = symbolic::integer(0);
1✔
82
    auto loop_condition = symbolic::Lt(loop_indvar, this->n_);
1✔
83
    auto loop_update = symbolic::add(loop_indvar, symbolic::integer(1));
1✔
84

85
    auto& loop =
1✔
86
        builder.add_for(new_sequence, loop_indvar, loop_condition, loop_init, loop_update, {}, block.debug_info());
1✔
87
    auto& body = loop.root();
1✔
88

89
    auto& new_block = builder.add_block(body);
1✔
90

91
    auto& res_in = builder.add_access(new_block, output_node_res.data());
1✔
92
    auto& res_out = builder.add_access(new_block, output_node_res.data());
1✔
93
    auto& x = builder.add_access(new_block, input_node_x.data());
1✔
94
    auto& y = builder.add_access(new_block, input_node_y.data());
1✔
95

96
    auto& tasklet = builder.add_tasklet(new_block, data_flow::TaskletCode::fp_fma, "_out", {"_in1", "_in2", "_in3"});
1✔
97

98
    builder.add_computational_memlet(
2✔
99
        new_block,
1✔
100
        x,
1✔
101
        tasklet,
1✔
102
        "_in1",
1✔
103
        {symbolic::mul(loop_indvar, this->incx_)},
1✔
104
        iedge_x->base_type(),
1✔
105
        iedge_x->debug_info()
1✔
106
    );
107
    builder.add_computational_memlet(
2✔
108
        new_block,
1✔
109
        y,
1✔
110
        tasklet,
1✔
111
        "_in2",
1✔
112
        {symbolic::mul(loop_indvar, this->incy_)},
1✔
113
        iedge_y->base_type(),
1✔
114
        iedge_y->debug_info()
1✔
115
    );
116
    builder
2✔
117
        .add_computational_memlet(new_block, res_in, tasklet, "_in3", {}, oedge_res->base_type(), oedge_res->debug_info());
1✔
118
    builder.add_computational_memlet(
2✔
119
        new_block, tasklet, "_out", res_out, {}, oedge_res->base_type(), oedge_res->debug_info()
1✔
120
    );
121

122
    // Clean up
123
    builder.remove_memlet(block, *iedge_x);
1✔
124
    builder.remove_memlet(block, *iedge_y);
1✔
125
    builder.remove_memlet(block, *oedge_res);
1✔
126
    builder.remove_node(block, input_node_x);
1✔
127
    builder.remove_node(block, input_node_y);
1✔
128
    builder.remove_node(block, output_node_res);
1✔
129
    builder.remove_node(block, *this);
1✔
130
    builder.remove_child(parent, index + 1);
1✔
131

132
    return true;
1✔
133
}
1✔
134

135
std::unique_ptr<data_flow::DataFlowNode> DotNode::
136
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
137
    auto node_clone = std::unique_ptr<DotNode>(new DotNode(
×
138
        element_id,
×
139
        this->debug_info(),
×
140
        vertex,
×
141
        parent,
×
142
        this->implementation_type_,
×
143
        this->precision_,
×
144
        this->n_,
×
145
        this->incx_,
×
146
        this->incy_
×
147
    ));
148
    return std::move(node_clone);
×
149
}
×
150

151
/// Serializes a DOT node into a JSON object.
///
/// The object holds the keys "code", "precision", "n", "incx" and "incy"; the
/// three expressions are converted with `JSONSerializer::expression`.
/// The argument is cast to `DotNode` without a runtime type check.
nlohmann::json DotNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& dot_node = static_cast<const DotNode&>(library_node);

    nlohmann::json result;
    serializer::JSONSerializer expr_writer;

    result["code"] = dot_node.code().value();
    result["precision"] = dot_node.precision();
    result["n"] = expr_writer.expression(dot_node.n());
    result["incx"] = expr_writer.expression(dot_node.incx());
    result["incy"] = expr_writer.expression(dot_node.incy());

    return result;
}
×
164

165
/// Rebuilds a DOT node from its JSON representation and adds it to `parent`.
///
/// Reads the keys "code", "debug_info", "precision", "n", "incx", "incy" and
/// "implementation_type". Every key is read with `at()`, so a missing key is
/// reported by the JSON library even when assertions are compiled out.
///
/// @param j        JSON object describing the node.
/// @param builder  Builder that creates the library node.
/// @param parent   Block that receives the new node.
/// @return Reference to the node created by the builder.
/// @throws std::runtime_error if "code" does not equal `LibraryNodeType_DOT`.
data_flow::LibraryNode& DotNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_DOT.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto n = symbolic::parse(j.at("n"));
    auto incx = symbolic::parse(j.at("incx"));
    auto incy = symbolic::parse(j.at("incy"));

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<DotNode>(parent, debug_info, implementation_type, precision, n, incx, incy);
}
×
191

192
/// Constructs the CBLAS dispatcher; all arguments are forwarded to `LibraryNodeDispatcher`.
DotNodeDispatcher_BLAS::DotNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {
}
×
199

200
void DotNodeDispatcher_BLAS::dispatch_code(
×
201
    codegen::PrettyPrinter& stream,
202
    codegen::PrettyPrinter& globals_stream,
203
    codegen::CodeSnippetFactory& library_snippet_factory
204
) {
205
    stream << "{" << std::endl;
×
206
    stream.setIndent(stream.indent() + 4);
×
207

208
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
209

210
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
211
    BLAS_Precision precision = dot_node.precision();
×
212
    switch (precision) {
×
213
        case BLAS_Precision::h:
214
            base_type = types::Scalar(types::PrimitiveType::Half);
×
215
            break;
×
216
        case BLAS_Precision::s:
217
            base_type = types::Scalar(types::PrimitiveType::Float);
×
218
            break;
×
219
        case BLAS_Precision::d:
220
            base_type = types::Scalar(types::PrimitiveType::Double);
×
221
            break;
×
222
        default:
223
            throw std::runtime_error("Invalid BLAS_Precision value");
×
224
    }
225

226
    stream << dot_node.outputs().at(0) << " = ";
×
227
    stream << "cblas_" << BLAS_Precision_to_string(precision) << "dot(";
×
228
    stream.setIndent(stream.indent() + 4);
×
229
    stream << this->language_extension_.expression(dot_node.n());
×
230
    stream << ", ";
×
231
    stream << dot_node.inputs().at(0);
×
232
    stream << ", ";
×
233
    stream << this->language_extension_.expression(dot_node.incx());
×
234
    stream << ", ";
×
235
    stream << dot_node.inputs().at(1);
×
236
    stream << ", ";
×
237
    stream << this->language_extension_.expression(dot_node.incy());
×
238
    stream.setIndent(stream.indent() - 4);
×
239
    stream << ");" << std::endl;
×
240

241
    stream.setIndent(stream.indent() - 4);
×
242
    stream << "}" << std::endl;
×
243
}
×
244

NEW
245
/// Constructs the cuBLAS dispatcher variant that emits host/device transfers;
/// all arguments are forwarded to `LibraryNodeDispatcher`.
DotNodeDispatcher_CUBLASWithTransfers::DotNodeDispatcher_CUBLASWithTransfers(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {
}
×
252

NEW
253
void DotNodeDispatcher_CUBLASWithTransfers::dispatch_code(
×
254
    codegen::PrettyPrinter& stream,
255
    codegen::PrettyPrinter& globals_stream,
256
    codegen::CodeSnippetFactory& library_snippet_factory
257
) {
NEW
258
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
259

NEW
260
    globals_stream << "#include <cuda.h>" << std::endl;
×
NEW
261
    globals_stream << "#include <cublas_v2.h>" << std::endl;
×
262

NEW
263
    std::string type, type2;
×
NEW
264
    switch (dot_node.precision()) {
×
265
        case s:
NEW
266
            type = "float";
×
NEW
267
            type2 = "S";
×
NEW
268
            break;
×
269
        case d:
NEW
270
            type = "double";
×
NEW
271
            type2 = "D";
×
NEW
272
            break;
×
273
        default:
NEW
274
            throw std::runtime_error("Invalid precision for CUBLAS DOT node");
×
275
    }
276

277
    const std::string x_size =
NEW
278
        this->language_extension_.expression(
×
NEW
279
            symbolic::add(symbolic::mul(symbolic::sub(dot_node.n(), symbolic::one()), dot_node.incx()), symbolic::one())
×
NEW
280
        ) +
×
NEW
281
        " * sizeof(" + type + ")";
×
282
    const std::string y_size =
NEW
283
        this->language_extension_.expression(
×
NEW
284
            symbolic::add(symbolic::mul(symbolic::sub(dot_node.n(), symbolic::one()), dot_node.incy()), symbolic::one())
×
NEW
285
        ) +
×
NEW
286
        " * sizeof(" + type + ")";
×
287

NEW
288
    stream << type << " *dx, *dy;" << std::endl;
×
NEW
289
    stream << "cudaMalloc(&dx, " << x_size << ");" << std::endl;
×
NEW
290
    stream << "cudaMalloc(&dy, " << y_size << ");" << std::endl;
×
291

NEW
292
    stream << "cudaMemcpy(dx, x, " << x_size << ", cudaMemcpyHostToDevice);" << std::endl;
×
NEW
293
    stream << "cudaMemcpy(dy, y, " << y_size << ", cudaMemcpyHostToDevice);" << std::endl;
×
294

NEW
295
    stream << "cublasStatus_t err;" << std::endl;
×
NEW
296
    stream << "cublasHandle_t handle;" << std::endl;
×
NEW
297
    stream << "err = cublasCreate(&handle);" << std::endl;
×
NEW
298
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
NEW
299
    stream.setIndent(stream.indent() + 4);
×
NEW
300
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
NEW
301
    stream.setIndent(stream.indent() - 4);
×
NEW
302
    stream << "}" << std::endl;
×
NEW
303
    stream << "err = cublas" << type2 << "dot(handle, " << this->language_extension_.expression(dot_node.n())
×
NEW
304
           << ", dx, " << this->language_extension_.expression(dot_node.incx()) << ", dy, "
×
NEW
305
           << this->language_extension_.expression(dot_node.incy()) << ", &_out);" << std::endl;
×
NEW
306
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
NEW
307
    stream.setIndent(stream.indent() + 4);
×
NEW
308
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
NEW
309
    stream.setIndent(stream.indent() - 4);
×
NEW
310
    stream << "}" << std::endl;
×
NEW
311
    stream << "err = cublasDestroy(handle);" << std::endl;
×
NEW
312
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
NEW
313
    stream.setIndent(stream.indent() + 4);
×
NEW
314
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
NEW
315
    stream.setIndent(stream.indent() - 4);
×
NEW
316
    stream << "}" << std::endl;
×
317

NEW
318
    stream << "cudaFree(dx);" << std::endl;
×
NEW
319
    stream << "cudaFree(dy);" << std::endl;
×
NEW
320
}
×
321

NEW
322
/// Constructs the cuBLAS dispatcher variant that emits no data transfers;
/// all arguments are forwarded to `LibraryNodeDispatcher`.
DotNodeDispatcher_CUBLASWithoutTransfers::DotNodeDispatcher_CUBLASWithoutTransfers(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {
}
×
329

NEW
330
void DotNodeDispatcher_CUBLASWithoutTransfers::dispatch_code(
×
331
    codegen::PrettyPrinter& stream,
332
    codegen::PrettyPrinter& globals_stream,
333
    codegen::CodeSnippetFactory& library_snippet_factory
334
) {
NEW
335
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
336

NEW
337
    globals_stream << "#include <cuda.h>" << std::endl;
×
NEW
338
    globals_stream << "#include <cublas_v2.h>" << std::endl;
×
339

NEW
340
    stream << "cublasStatus_t err;" << std::endl;
×
NEW
341
    stream << "cublasHandle_t handle;" << std::endl;
×
NEW
342
    stream << "err = cublasCreate(&handle);" << std::endl;
×
NEW
343
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
NEW
344
    stream.setIndent(stream.indent() + 4);
×
NEW
345
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
NEW
346
    stream.setIndent(stream.indent() - 4);
×
NEW
347
    stream << "}" << std::endl;
×
NEW
348
    stream << "err = cublas";
×
NEW
349
    switch (dot_node.precision()) {
×
350
        case s:
NEW
351
            stream << "S";
×
NEW
352
            break;
×
353
        case d:
NEW
354
            stream << "D";
×
NEW
355
            break;
×
356
        default:
NEW
357
            throw std::runtime_error("Invalid precision for CUBLAS DOT node");
×
358
    }
NEW
359
    stream << "dot(handle, " << this->language_extension_.expression(dot_node.n()) << ", x, "
×
NEW
360
           << this->language_extension_.expression(dot_node.incx()) << ", y, "
×
NEW
361
           << this->language_extension_.expression(dot_node.incy()) << ", &_out);" << std::endl;
×
NEW
362
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
NEW
363
    stream.setIndent(stream.indent() + 4);
×
NEW
364
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
NEW
365
    stream.setIndent(stream.indent() - 4);
×
NEW
366
    stream << "}" << std::endl;
×
NEW
367
    stream << "err = cublasDestroy(handle);" << std::endl;
×
NEW
368
    stream << "if (err != CUBLAS_STATUS_SUCCESS) {" << std::endl;
×
NEW
369
    stream.setIndent(stream.indent() + 4);
×
NEW
370
    stream << this->language_extension_.external_prefix() << "exit(1);" << std::endl;
×
NEW
371
    stream.setIndent(stream.indent() - 4);
×
NEW
372
    stream << "}" << std::endl;
×
UNCOV
373
}
×
374

375
} // namespace blas
376
} // namespace math
377
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc