• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 26463753889

26 May 2026 05:18PM UTC coverage: 60.864% (-0.02%) from 60.886%
26463753889

Pull #719

github

web-flow
Merge 0b90ddd88 into 707dadcf8
Pull Request #719: Libnode ptr edges

961 of 1749 new or added lines in 52 files covered. (54.95%)

90 existing lines in 29 files now uncovered.

35222 of 57870 relevant lines covered (60.86%)

11043.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.88
/sdfg/src/data_flow/library_nodes/math/blas/dot_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/blas/dot_node.h"
2
#include <stdexcept>
3
#include <string>
4

5
#include "sdfg/analysis/analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7

8
#include "sdfg/analysis/scope_analysis.h"
9
#include "sdfg/symbolic/symbolic.h"
10

11
namespace sdfg {
12
namespace math {
13
namespace blas {
14

15
DotNode::DotNode(
16
    size_t element_id,
17
    const DebugInfo& debug_info,
18
    const graph::Vertex vertex,
19
    data_flow::DataFlowGraph& parent,
20
    const data_flow::ImplementationType& implementation_type,
21
    const BLAS_Precision& precision,
22
    symbolic::Expression n,
23
    symbolic::Expression incx,
24
    symbolic::Expression incy
25
)
26
    : BLASNode(
11✔
27
          element_id,
11✔
28
          debug_info,
11✔
29
          vertex,
11✔
30
          parent,
11✔
31
          LibraryNodeType_DOT,
11✔
32
          {"__out"},
11✔
33
          {"__x", "__y"},
11✔
34
          implementation_type,
11✔
35
          precision
11✔
36
      ),
11✔
37
      n_(n), incx_(incx), incy_(incy) {}
11✔
38

39
// Symbolic number of element pairs multiplied (the loop bound used by expand()).
symbolic::Expression DotNode::n() const { return this->n_; }

// Symbolic stride between consecutive "__x" elements (x is indexed as i * incx).
symbolic::Expression DotNode::incx() const { return this->incx_; }

// Symbolic stride between consecutive "__y" elements (y is indexed as i * incy).
symbolic::Expression DotNode::incy() const { return this->incy_; }
44

45
// Returns every symbol that occurs in the node's symbolic parameters
// (n, incx and incy). The result tells callers which symbols this node
// depends on.
symbolic::SymbolSet DotNode::symbols() const {
    symbolic::SymbolSet syms;
    for (const auto& expr : {this->n_, this->incx_, this->incy_}) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    }
    return syms;
}
60

61
// Substitutes `old_expression` with `new_expression` in the node's symbolic
// parameters n, incx and incy. Each parameter is rewritten in place.
void DotNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    this->n_ = symbolic::subs(this->n_, old_expression, new_expression);
    this->incx_ = symbolic::subs(this->incx_, old_expression, new_expression);
    this->incy_ = symbolic::subs(this->incy_, old_expression, new_expression);
}
66

67
// Validates the node by delegating to the generic BLAS node checks.
// NOTE(review): nothing DOT-specific is checked here (e.g. incx / incy != 0);
// confirm whether BLASNode::validate is expected to cover that.
void DotNode::validate(const Function& function) const {
    BLASNode::validate(function);
}
68

69
bool DotNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
3✔
70
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
3✔
71

72
    auto& dataflow = this->get_parent();
3✔
73
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
3✔
74
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
3✔
75
    int index = parent.index(block);
3✔
76
    auto& transition = parent.at(index).second;
3✔
77

78
    const data_flow::Memlet* iedge_x = nullptr;
3✔
79
    const data_flow::Memlet* iedge_y = nullptr;
3✔
80
    for (const auto& iedge : dataflow.in_edges(*this)) {
6✔
81
        if (iedge.dst_conn() == "__x") {
6✔
82
            iedge_x = &iedge;
3✔
83
        } else if (iedge.dst_conn() == "__y") {
3✔
84
            iedge_y = &iedge;
3✔
85
        }
3✔
86
    }
6✔
87

88
    const data_flow::Memlet* oedge_res = nullptr;
3✔
89
    for (const auto& oedge : dataflow.out_edges(*this)) {
3✔
90
        if (oedge.src_conn() == "__out") {
3✔
91
            oedge_res = &oedge;
3✔
92
            break;
3✔
93
        }
3✔
94
    }
3✔
95

96
    // Check if legal
97
    auto& input_node_x = static_cast<const data_flow::AccessNode&>(iedge_x->src());
3✔
98
    auto& input_node_y = static_cast<const data_flow::AccessNode&>(iedge_y->src());
3✔
99
    auto& output_node_res = static_cast<const data_flow::AccessNode&>(oedge_res->dst());
3✔
100
    if (dataflow.in_degree(input_node_x) != 0 || dataflow.in_degree(input_node_y) != 0 ||
3✔
101
        dataflow.out_degree(output_node_res) != 0) {
3✔
102
        return false;
×
103
    }
×
104

105
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
3✔
106

107
    std::string loop_var = builder.find_new_name("_i");
3✔
108
    builder.add_container(loop_var, types::Scalar(types::PrimitiveType::UInt64));
3✔
109

110
    auto loop_indvar = symbolic::symbol(loop_var);
3✔
111
    auto loop_init = symbolic::integer(0);
3✔
112
    auto loop_condition = symbolic::Lt(loop_indvar, this->n_);
3✔
113
    auto loop_update = symbolic::add(loop_indvar, symbolic::integer(1));
3✔
114

115
    auto& loop =
3✔
116
        builder.add_for(new_sequence, loop_indvar, loop_condition, loop_init, loop_update, {}, block.debug_info());
3✔
117
    auto& body = loop.root();
3✔
118

119
    auto& new_block = builder.add_block(body);
3✔
120

121
    auto& res_in = builder.add_access(new_block, output_node_res.data());
3✔
122
    auto& res_out = builder.add_access(new_block, output_node_res.data());
3✔
123
    auto& x = builder.add_access(new_block, input_node_x.data());
3✔
124
    auto& y = builder.add_access(new_block, input_node_y.data());
3✔
125

126
    auto& tasklet = builder.add_tasklet(new_block, data_flow::TaskletCode::fp_fma, "__out", {"_in1", "_in2", "_in3"});
3✔
127

128
    builder.add_computational_memlet(
3✔
129
        new_block,
3✔
130
        x,
3✔
131
        tasklet,
3✔
132
        "_in1",
3✔
133
        {symbolic::mul(loop_indvar, this->incx_)},
3✔
134
        iedge_x->base_type(),
3✔
135
        iedge_x->debug_info()
3✔
136
    );
3✔
137
    builder.add_computational_memlet(
3✔
138
        new_block,
3✔
139
        y,
3✔
140
        tasklet,
3✔
141
        "_in2",
3✔
142
        {symbolic::mul(loop_indvar, this->incy_)},
3✔
143
        iedge_y->base_type(),
3✔
144
        iedge_y->debug_info()
3✔
145
    );
3✔
146
    builder
3✔
147
        .add_computational_memlet(new_block, res_in, tasklet, "_in3", {}, oedge_res->base_type(), oedge_res->debug_info());
3✔
148
    builder.add_computational_memlet(
3✔
149
        new_block, tasklet, "__out", res_out, {}, oedge_res->base_type(), oedge_res->debug_info()
3✔
150
    );
3✔
151

152
    // Clean up
153
    builder.remove_memlet(block, *iedge_x);
3✔
154
    builder.remove_memlet(block, *iedge_y);
3✔
155
    builder.remove_memlet(block, *oedge_res);
3✔
156
    builder.remove_node(block, input_node_x);
3✔
157
    builder.remove_node(block, input_node_y);
3✔
158
    builder.remove_node(block, output_node_res);
3✔
159
    builder.remove_node(block, *this);
3✔
160
    builder.remove_child(parent, index + 1);
3✔
161

162
    return true;
3✔
163
}
3✔
164

165
// Floating-point operation count of a dot product: n multiplications plus
// (n - 1) additions to reduce the products, i.e. n + (n - 1).
// NOTE(review): for n == 0 this evaluates to -1. Confirm whether consumers
// clamp it.
symbolic::Expression DotNode::flop() const {
    return symbolic::add(this->n_, symbolic::sub(this->n_, symbolic::one()));
}
170

171
std::unique_ptr<data_flow::DataFlowNode> DotNode::
172
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
173
    auto node_clone = std::unique_ptr<DotNode>(new DotNode(
×
174
        element_id,
×
175
        this->debug_info(),
×
176
        vertex,
×
177
        parent,
×
178
        this->implementation_type_,
×
179
        this->precision_,
×
180
        this->n_,
×
181
        this->incx_,
×
182
        this->incy_
×
183
    ));
×
184
    return std::move(node_clone);
×
185
}
×
186

NEW
187
// Describes how the node accesses each input pointer. Both inputs ("__x" at
// index 0, "__y" at index 1) are only read, with an extent of n * stride.
// Any other index is answered by BLASNode.
// NOTE(review): a strided read of n elements touches (n - 1) * |inc| + 1
// elements. n * inc overstates this by inc - 1 and is negative for negative
// strides. Confirm what create_read_only expects.
data_flow::PointerAccessType DotNode::pointer_access_type(int input_idx) const {
    switch (input_idx) {
        case 0: // "__x"
            return data_flow::PointerAccessMeta::create_read_only(symbolic::mul(n_, incx_), true);
        case 1: // "__y"
            return data_flow::PointerAccessMeta::create_read_only(symbolic::mul(n_, incy_), true);
        default:
            return BLASNode::pointer_access_type(input_idx);
    }
}
196

197
// Serializes the DOT-specific fields of `library_node` into a JSON object:
// the library node code, the precision, and the symbolic n / incx / incy.
nlohmann::json DotNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& dot_node = static_cast<const DotNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    nlohmann::json j;
    j["code"] = dot_node.code().value();
    j["precision"] = dot_node.precision();
    j["n"] = expr_serializer.expression(dot_node.n());
    j["incx"] = expr_serializer.expression(dot_node.incx());
    j["incy"] = expr_serializer.expression(dot_node.incy());
    return j;
}
210

211
// Rebuilds a DotNode from JSON and adds it to `parent` through `builder`.
//
// Reads: "code" (must equal the DOT library node code), "debug_info",
// "precision", "n", "incx", "incy" and "implementation_type".
//
// Throws std::runtime_error if "code" names a different library node.
// Throws nlohmann::json::out_of_range if a required field is missing.
data_flow::LibraryNode& DotNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Assertions for required fields. These are active in debug builds only;
    // the .at() lookups below keep missing keys well-defined (they throw)
    // when NDEBUG strips the asserts.
    // NOTE(review): "element_id" is checked but not read here; presumably the
    // builder assigns a fresh id. Verify.
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_DOT.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    // Extract debug info using JSONSerializer
    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto precision = j.at("precision").get<BLAS_Precision>();
    auto n = symbolic::parse(j.at("n"));
    auto incx = symbolic::parse(j.at("incx"));
    auto incy = symbolic::parse(j.at("incy"));

    auto implementation_type = j.at("implementation_type").get<std::string>();

    return builder.add_library_node<DotNode>(parent, debug_info, implementation_type, precision, n, incx, incy);
}
237

238
// Code-generation dispatcher that lowers a DotNode to a CBLAS call. The
// constructor adds no initialisation of its own; all arguments are forwarded
// to the generic library-node dispatcher.
DotNodeDispatcher_BLAS::DotNodeDispatcher_BLAS(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const DotNode& node
)
    : codegen::LibraryNodeDispatcher(
          language_extension,
          function,
          data_flow_graph,
          node
      ) {}
245

246
void DotNodeDispatcher_BLAS::dispatch_code_with_edges(
247
    codegen::CodegenOutput& out,
248
    std::vector<codegen::DispatchInput>& inputs,
249
    std::vector<codegen::DispatchOutput>& outputs
250
) {
×
251
    auto& dot_node = static_cast<const DotNode&>(this->node_);
×
252

253
    sdfg::types::Scalar base_type(types::PrimitiveType::Void);
×
254
    BLAS_Precision precision = dot_node.precision();
×
255
    switch (precision) {
×
256
        case BLAS_Precision::h:
×
257
            base_type = types::Scalar(types::PrimitiveType::Half);
×
258
            break;
×
259
        case BLAS_Precision::s:
×
260
            base_type = types::Scalar(types::PrimitiveType::Float);
×
261
            break;
×
262
        case BLAS_Precision::d:
×
263
            base_type = types::Scalar(types::PrimitiveType::Double);
×
264
            break;
×
265
        default:
×
266
            throw std::runtime_error("Invalid BLAS_Precision value");
×
267
    }
×
268

NEW
269
    out.library_snippet_factory.require_dependency(BLASLibDependency::instance());
×
270

NEW
271
    auto& output = outputs.at(0);
×
NEW
272
    pre_allocate_output(out, output, dot_node.output(0));
×
273

NEW
274
    out.stream << *output.local_name << " = ";
×
NEW
275
    out.stream << "cblas_" << BLAS_Precision_to_string(precision) << "dot(";
×
NEW
276
    out.stream.changeIndent(+4);
×
NEW
277
    out.stream << this->language_extension_.expression(dot_node.n());
×
NEW
278
    out.stream << ", ";
×
NEW
279
    out.stream << inputs.at(0).expr;
×
NEW
280
    out.stream << ", ";
×
NEW
281
    out.stream << this->language_extension_.expression(dot_node.incx());
×
NEW
282
    out.stream << ", ";
×
NEW
283
    out.stream << inputs.at(1).expr;
×
NEW
284
    out.stream << ", ";
×
NEW
285
    out.stream << this->language_extension_.expression(dot_node.incy());
×
NEW
286
    out.stream.changeIndent(-4);
×
NEW
287
    out.stream << ");" << std::endl;
×
UNCOV
288
}
×
289

290

291
} // namespace blas
292
} // namespace math
293
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc