• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 16779684622

06 Aug 2025 02:21PM UTC coverage: 64.3% (-1.0%) from 65.266%
16779684622

push

github

web-flow
Merge pull request #172 from daisytuner/opaque-pointers

Opaque pointers, typed memlets, untyped tasklet connectors

330 of 462 new or added lines in 38 files covered. (71.43%)

382 existing lines in 30 files now uncovered.

8865 of 13787 relevant lines covered (64.3%)

116.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

34.95
/src/data_flow/library_nodes/math/blas/dot.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/dot.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace blas {
11

12
/// Builds a DOT library node.
///
/// The node is registered with the base class as `LibraryNodeType_DOT`, with a
/// single output connector "res" and the two input connectors "x" and "y".
/// `precision`, `n`, `incx` and `incy` are stored unchanged and exposed through
/// the accessors of the same name.
DotNode::DotNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const data_flow::ImplementationType& implementation_type,
    const BLAS_Precision& precision,
    symbolic::Expression n,
    symbolic::Expression incx,
    symbolic::Expression incy
)
    : MathNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_DOT,
          /*outputs=*/{"res"},
          /*inputs=*/{"x", "y"},
          implementation_type
      ),
      precision_(precision),
      n_(n),
      incx_(incx),
      incy_(incy) {}
1✔
25

26
/// Returns the BLAS precision this node was constructed with.
BLAS_Precision DotNode::precision() const { return this->precision_; }
×
27

28
/// Returns the symbolic element count `n` (used as the loop bound in expand()).
symbolic::Expression DotNode::n() const { return this->n_; }
×
29

30
/// Returns the symbolic stride applied to the "x" input.
symbolic::Expression DotNode::incx() const { return this->incx_; }
×
31

32
/// Returns the symbolic stride applied to the "y" input.
symbolic::Expression DotNode::incy() const { return this->incy_; }
×
33

34
void DotNode::validate(const Function& function) const {
×
35
    auto& graph = this->get_parent();
×
36

37
    if (graph.in_degree(*this) != this->inputs_.size()) {
×
38
        throw InvalidSDFGException("DotNode must have " + std::to_string(this->inputs_.size()) + " inputs");
×
39
    }
40
    if (graph.out_degree(*this) != 1) {
×
41
        throw InvalidSDFGException("DotNode must have 1 output");
×
42
    }
43

44
    std::unordered_map<std::string, const data_flow::Memlet*> memlets;
×
45
    for (auto& input : this->inputs_) {
×
46
        bool found = false;
×
47
        for (auto& iedge : graph.in_edges(*this)) {
×
48
            if (iedge.dst_conn() == input) {
×
49
                found = true;
×
50
                memlets[input] = &iedge;
×
51
                break;
×
52
            }
53
        }
54
        if (!found) {
×
55
            throw InvalidSDFGException("DotNode input " + input + " not found");
×
56
        }
57
    }
58

59
    auto& oedge = *graph.out_edges(*this).begin();
×
60
    if (oedge.src_conn() != this->outputs_.at(0)) {
×
61
        throw InvalidSDFGException("DotNode output " + this->outputs_.at(0) + " not found");
×
62
    }
63

64
    auto& x_memlet = memlets.at("x");
×
65
    auto& x_subset_begin = x_memlet->begin_subset();
×
66
    auto& x_subset_end = x_memlet->end_subset();
×
67
    if (x_subset_begin.size() != 1) {
×
68
        throw InvalidSDFGException("DotNode input x must have 1 dimensions");
×
69
    }
70

71
    auto& y_memlet = memlets.at("y");
×
72
    auto& y_subset_begin = y_memlet->begin_subset();
×
73
    auto& y_subset_end = y_memlet->end_subset();
×
74
    if (y_subset_begin.size() != 1) {
×
75
        throw InvalidSDFGException("DotNode input y must have 1 dimensions");
×
76
    }
77
}
×
78

79
/// Replaces this library node by an explicit loop of `fma` tasklets.
///
/// A new sequence is inserted before the block that contains this node. It
/// holds a for-loop over a fresh UInt64 induction variable running from 0 while
/// it is less than `n`, incremented by 1. Each iteration reads `x` at
/// `i * incx`, `y` at `i * incy` and the current `res`, and writes `res`
/// through an `fma` tasklet. Afterwards the original memlets, access nodes,
/// this node and the enclosing block are removed.
///
/// @param builder          Builder used to create the loop and remove the old nodes.
/// @param analysis_manager Source of the ScopeAnalysis used to find the block's parent.
/// @return true if the node was expanded; false (graph untouched) when a
///         connector has no memlet, when an input access node has incoming
///         edges, or when the output access node has outgoing edges.
bool DotNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    auto& dataflow = this->get_parent();
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());

    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));

    const data_flow::Memlet* iedge_x = nullptr;
    const data_flow::Memlet* iedge_y = nullptr;
    for (const auto& iedge : dataflow.in_edges(*this)) {
        if (iedge.dst_conn() == "x") {
            iedge_x = &iedge;
        } else if (iedge.dst_conn() == "y") {
            iedge_y = &iedge;
        }
    }

    const data_flow::Memlet* oedge_res = nullptr;
    for (const auto& oedge : dataflow.out_edges(*this)) {
        if (oedge.src_conn() == "res") {
            oedge_res = &oedge;
            break;
        }
    }

    // A connector without a memlet cannot be expanded; bail out before dereferencing.
    if (iedge_x == nullptr || iedge_y == nullptr || oedge_res == nullptr) {
        return false;
    }

    // Check if legal
    // NOTE(review): the endpoints are cast to AccessNode without a type check — confirm
    // that a DotNode can only ever be connected to access nodes.
    auto& input_node_x = static_cast<const data_flow::AccessNode&>(iedge_x->src());
    auto& input_node_y = static_cast<const data_flow::AccessNode&>(iedge_y->src());
    auto& output_node_res = static_cast<const data_flow::AccessNode&>(oedge_res->dst());
    if (dataflow.in_degree(input_node_x) != 0 || dataflow.in_degree(input_node_y) != 0 ||
        dataflow.out_degree(output_node_res) != 0) {
        return false;
    }

    auto& new_sequence = builder.add_sequence_before(parent, block, block.debug_info()).first;

    std::string loop_var = builder.find_new_name("_i");
    builder.add_container(loop_var, types::Scalar(types::PrimitiveType::UInt64));

    auto loop_indvar = symbolic::symbol(loop_var);
    auto loop_init = symbolic::integer(0);
    auto loop_condition = symbolic::Lt(loop_indvar, this->n_);
    auto loop_update = symbolic::add(loop_indvar, symbolic::integer(1));

    auto& loop =
        builder.add_for(new_sequence, loop_indvar, loop_condition, loop_init, loop_update, {}, block.debug_info());
    auto& body = loop.root();

    auto& new_block = builder.add_block(body);

    // `res` appears twice: once as the value read by the tasklet, once as the value written.
    // NOTE(review): `res` is read on every iteration, including the first, and nothing in this
    // function initializes it — confirm the accumulator is zeroed elsewhere.
    auto& res_in = builder.add_access(new_block, output_node_res.data());
    auto& res_out = builder.add_access(new_block, output_node_res.data());
    auto& x = builder.add_access(new_block, input_node_x.data());
    auto& y = builder.add_access(new_block, input_node_y.data());

    // presumably _out = _in1 * _in2 + _in3 — confirm against TaskletCode::fma
    auto& tasklet = builder.add_tasklet(new_block, data_flow::TaskletCode::fma, "_out", {"_in1", "_in2", "_in3"});

    builder.add_computational_memlet(
        new_block,
        x,
        tasklet,
        "_in1",
        {symbolic::mul(loop_indvar, this->incx_)},
        iedge_x->base_type(),
        iedge_x->debug_info()
    );
    builder.add_computational_memlet(
        new_block,
        y,
        tasklet,
        "_in2",
        {symbolic::mul(loop_indvar, this->incy_)},
        iedge_y->base_type(),
        iedge_y->debug_info()
    );
    builder
        .add_computational_memlet(new_block, res_in, tasklet, "_in3", {}, oedge_res->base_type(), oedge_res->debug_info());
    builder.add_computational_memlet(
        new_block, tasklet, "_out", res_out, {}, oedge_res->base_type(), oedge_res->debug_info()
    );

    // Clean up
    builder.remove_memlet(block, *iedge_x);
    builder.remove_memlet(block, *iedge_y);
    builder.remove_memlet(block, *oedge_res);
    builder.remove_node(block, input_node_x);
    builder.remove_node(block, input_node_y);
    builder.remove_node(block, output_node_res);
    builder.remove_node(block, *this);
    builder.remove_child(parent, block);

    return true;
}
1✔
172

173
std::unique_ptr<data_flow::DataFlowNode> DotNode::
174
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
175
    auto node_clone = std::unique_ptr<DotNode>(new DotNode(
×
176
        element_id,
×
177
        this->debug_info(),
×
178
        vertex,
×
179
        parent,
×
180
        this->implementation_type_,
×
181
        this->precision_,
×
182
        this->n_,
×
183
        this->incx_,
×
184
        this->incy_
×
185
    ));
186
    return std::move(node_clone);
×
187
}
×
188

189
/// Serializes the DOT-specific fields of a library node to JSON.
///
/// Writes the keys "code", "precision", "n", "incx" and "incy"; the three
/// symbolic values are rendered through JSONSerializer::expression.
///
/// @param library_node Node to serialize; it is cast to DotNode without a check.
/// @return JSON object containing the five keys listed above.
nlohmann::json DotNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& dot_node = static_cast<const DotNode&>(library_node);
    serializer::JSONSerializer serializer;

    nlohmann::json result;
    result["code"] = dot_node.code().value();
    result["precision"] = dot_node.precision();
    result["n"] = serializer.expression(dot_node.n());
    result["incx"] = serializer.expression(dot_node.incx());
    result["incy"] = serializer.expression(dot_node.incy());
    return result;
}
×
202

203
/// Rebuilds a DotNode from its JSON form and adds it to `parent`.
///
/// Reads "code", "debug_info", "precision", "n", "incx", "incy" and
/// "implementation_type". "element_id" is asserted to be present but is not
/// read here.
///
/// @param j       JSON object describing the node.
/// @param builder Builder that creates the node.
/// @param parent  Block that receives the new node.
/// @return Reference to the library node created by the builder.
/// @throws std::runtime_error if "code" is not the DOT library node code.
/// @throws nlohmann::json exceptions (from `at`/`get`) if a read key is missing
///         or has the wrong type.
data_flow::LibraryNode& DotNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    // `at` is used for every lookup: unlike const operator[], it stays well-defined
    // for a missing key even when the assertions above are compiled out.
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_DOT.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto n = SymEngine::Expression(j.at("n"));
    auto incx = SymEngine::Expression(j.at("incx"));
    auto incy = SymEngine::Expression(j.at("incy"));

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<DotNode>(parent, debug_info, implementation_type, precision, n, incx, incy);
}
×
229

230
/// Creates the BLAS dispatcher for a DotNode; all arguments are forwarded
/// unchanged to the LibraryNodeDispatcher base class.
DotNodeDispatcher_BLAS::DotNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {}
×
237

238
void DotNodeDispatcher_BLAS::dispatch(
×
239
    codegen::PrettyPrinter& stream,
240
    codegen::PrettyPrinter& globals_stream,
241
    codegen::CodeSnippetFactory& library_snippet_factory
242
) {
UNCOV
243
    stream << "{" << std::endl;
×
244
    stream.setIndent(stream.indent() + 4);
×
245

UNCOV
246
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
247

248
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
UNCOV
249
    switch (dot_node.precision()) {
×
250
        case BLAS_Precision::h:
251
            base_type = types::Scalar(types::PrimitiveType::Half);
×
UNCOV
252
            break;
×
253
        case BLAS_Precision::s:
254
            base_type = types::Scalar(types::PrimitiveType::Float);
×
UNCOV
255
            break;
×
256
        case BLAS_Precision::d:
UNCOV
257
            base_type = types::Scalar(types::PrimitiveType::Double);
×
UNCOV
258
            break;
×
259
        default:
260
            throw std::runtime_error("Invalid BLAS_Precision value");
×
261
    }
262

263
    auto& graph = this->node_.get_parent();
×
UNCOV
264
    for (auto& iedge : graph.in_edges(this->node_)) {
×
265
        auto& access_node = static_cast<const data_flow::AccessNode&>(iedge.src());
×
266
        std::string name = access_node.data();
×
267
        auto& type = this->function_.type(name);
×
268

269
        stream << this->language_extension_.declaration(iedge.dst_conn(), type);
×
270
        stream << " = " << name << ";" << std::endl;
×
271
    }
×
UNCOV
272
    for (auto& oedge : graph.out_edges(this->node_)) {
×
273
        auto& access_node = static_cast<const data_flow::AccessNode&>(oedge.dst());
×
274
        std::string name = access_node.data();
×
275
        auto& type = this->function_.type(name);
×
276

277
        stream << this->language_extension_.declaration(oedge.src_conn(), type);
×
278
        stream << ";" << std::endl;
×
279
    }
×
280

281
    std::string res_name = this->node_.outputs().at(0);
×
282
    stream << res_name << " = ";
×
283
    stream << "cblas_" << BLAS_Precision_to_string(dot_node.precision()) << "dot(";
×
284
    stream.setIndent(stream.indent() + 4);
×
285
    stream << this->language_extension_.expression(dot_node.n());
×
286
    stream << ", ";
×
287
    stream << "x";
×
288
    stream << ", ";
×
289
    stream << this->language_extension_.expression(dot_node.incx());
×
290
    stream << ", ";
×
291
    stream << "y";
×
UNCOV
292
    stream << ", ";
×
293
    stream << this->language_extension_.expression(dot_node.incy());
×
294
    stream.setIndent(stream.indent() - 4);
×
295
    stream << ");" << std::endl;
×
296

297
    for (auto& oedge : graph.out_edges(this->node_)) {
×
298
        auto& access_node = static_cast<const data_flow::AccessNode&>(oedge.dst());
×
UNCOV
299
        std::string name = access_node.data();
×
300
        auto& type = this->function_.type(name);
×
301
        stream << name << " = " << oedge.src_conn() << ";" << std::endl;
×
302
    }
×
303

304
    stream.setIndent(stream.indent() - 4);
×
UNCOV
305
    stream << "}" << std::endl;
×
UNCOV
306
}
×
307

UNCOV
308
/// Creates the cuBLAS dispatcher for a DotNode; all arguments are forwarded
/// unchanged to the LibraryNodeDispatcher base class.
DotNodeDispatcher_CUBLAS::DotNodeDispatcher_CUBLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {}
×
315

UNCOV
316
void DotNodeDispatcher_CUBLAS::dispatch(
×
317
    codegen::PrettyPrinter& stream,
318
    codegen::PrettyPrinter& globals_stream,
319
    codegen::CodeSnippetFactory& library_snippet_factory
320
) {
UNCOV
321
    throw std::runtime_error("DotNodeDispatcher_CUBLAS not implemented");
×
UNCOV
322
}
×
323

324
} // namespace blas
325
} // namespace math
326
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc