• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20585279183

29 Dec 2025 11:48PM UTC coverage: 39.581% (-0.8%) from 40.359%
20585279183

push

github

web-flow
Merge pull request #412 from daisytuner/mean-std-nodes

adds mean/std library nodes

14647 of 48066 branches covered (30.47%)

Branch coverage included in aggregate %.

225 of 622 new or added lines in 14 files covered. (36.17%)

41 existing lines in 1 file now uncovered.

12489 of 20493 relevant lines covered (60.94%)

87.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

50.91
/src/data_flow/library_nodes/math/tensor/reduce_ops/std_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/std_node.h"
2
#include "sdfg/analysis/scope_analysis.h"
3
#include "sdfg/builder/structured_sdfg_builder.h"
4
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/pow_node.h"
5
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/sqrt_node.h"
6
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/sub_node.h"
7
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/mean_node.h"
8
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
9
#include "sdfg/structured_control_flow/block.h"
10
#include "sdfg/types/scalar.h"
11
#include "sdfg/types/utils.h"
12

13
namespace sdfg {
14
namespace math {
15
namespace tensor {
16

17
/// Constructs a standard-deviation reduction node.
///
/// All state (shape, reduced axes, keepdims flag) is held by the ReduceNode base;
/// this constructor only tags the node with LibraryNodeType_Std.
StdNode::StdNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<int64_t>& axes,
    bool keepdims
)
    : ReduceNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_Std,
          shape,
          axes,
          keepdims
      ) {}
27

28
bool StdNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
29
    auto& dataflow = this->get_parent();
2✔
30
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
2✔
31

32
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
2!
NEW
33
        return false;
×
34
    }
35

36
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
2✔
37
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
2✔
38
    int index = parent.index(block);
2✔
39
    auto& transition = parent.at(index).second;
2✔
40

41
    auto& in_edge = *dataflow.in_edges(*this).begin();
2✔
42
    auto& out_edge = *dataflow.out_edges(*this).begin();
2✔
43
    auto& in_node = static_cast<data_flow::AccessNode&>(in_edge.src());
2✔
44
    auto& out_node = static_cast<data_flow::AccessNode&>(out_edge.dst());
2✔
45

46
    // Calculate output shape
47
    std::vector<symbolic::Expression> output_shape;
2✔
48
    std::vector<int64_t> sorted_axes = axes_;
2!
49
    std::sort(sorted_axes.begin(), sorted_axes.end());
2!
50

51
    for (size_t i = 0; i < shape_.size(); ++i) {
4✔
52
        bool is_axis = false;
2✔
53
        for (auto axis : sorted_axes) {
2!
54
            if (axis == (int64_t) i) {
2!
55
                is_axis = true;
2✔
56
                break;
2✔
57
            }
58
        }
59

60
        if (is_axis) {
2!
61
            if (keepdims_) {
2!
NEW
62
                output_shape.push_back(symbolic::one());
×
NEW
63
            }
×
64
        } else {
2✔
NEW
65
            output_shape.push_back(shape_[i]);
×
66
        }
67
    }
2✔
68

69
    std::string tmp_x2_name = builder.find_new_name("_std_x2");
2!
70
    builder.add_container(tmp_x2_name, in_edge.base_type());
2!
71
    std::string tmp_mean_x2_name = builder.find_new_name("_std_mean_x2");
2!
72
    builder.add_container(tmp_mean_x2_name, out_edge.base_type());
2!
73
    std::string tmp_mean_x_name = builder.find_new_name("_std_mean_x");
2!
74
    builder.add_container(tmp_mean_x_name, out_edge.base_type());
2!
75

76
    if (in_edge.base_type().type_id() == types::TypeID::Pointer) {
2!
77
        auto& pointee_type = static_cast<const types::Pointer&>(in_edge.base_type()).pointee_type();
2!
78
        symbolic::Expression bytes_in = types::get_type_size(pointee_type, false);
2!
79
        for (auto& dim : this->shape_) {
4✔
80
            bytes_in = symbolic::mul(dim, bytes_in);
2!
81
        }
82

83
        {
84
            auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
2!
85
            auto& tmp_x2_name_access = builder.add_access(alloc_block, tmp_x2_name);
2!
86
            auto& tmp_x2_name_malloc_node =
2✔
87
                builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), bytes_in);
2!
88
            builder.add_computational_memlet(
4!
89
                alloc_block,
2✔
90
                tmp_x2_name_malloc_node,
2✔
91
                "_ret",
2!
92
                tmp_x2_name_access,
2✔
93
                {},
2✔
94
                in_edge.base_type(),
2!
95
                this->debug_info()
2!
96
            );
97
        }
98
    }
2✔
99

100
    if (out_edge.base_type().type_id() == types::TypeID::Pointer) {
2!
NEW
101
        auto& pointee_type = static_cast<const types::Pointer&>(out_edge.base_type()).pointee_type();
×
NEW
102
        symbolic::Expression bytes_out = types::get_type_size(pointee_type, false);
×
NEW
103
        for (auto& dim : output_shape) {
×
NEW
104
            bytes_out = symbolic::mul(dim, bytes_out);
×
105
        }
106

107
        {
NEW
108
            auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
109
            auto& tmp_mean_x2_name_access = builder.add_access(alloc_block, tmp_mean_x2_name);
×
NEW
110
            auto& tmp_mean_x2_name_malloc_node =
×
NEW
111
                builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), bytes_out);
×
NEW
112
            builder.add_computational_memlet(
×
NEW
113
                alloc_block,
×
NEW
114
                tmp_mean_x2_name_malloc_node,
×
NEW
115
                "_ret",
×
NEW
116
                tmp_mean_x2_name_access,
×
NEW
117
                {},
×
NEW
118
                in_edge.base_type(),
×
NEW
119
                this->debug_info()
×
120
            );
121
        }
122

123
        {
NEW
124
            auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
125
            auto& tmp_mean_x_name_access = builder.add_access(alloc_block, tmp_mean_x_name);
×
NEW
126
            auto& tmp_mean_x_name_malloc_node =
×
NEW
127
                builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), bytes_out);
×
NEW
128
            builder.add_computational_memlet(
×
NEW
129
                alloc_block,
×
NEW
130
                tmp_mean_x_name_malloc_node,
×
NEW
131
                "_ret",
×
NEW
132
                tmp_mean_x_name_access,
×
NEW
133
                {},
×
NEW
134
                in_edge.base_type(),
×
NEW
135
                this->debug_info()
×
136
            );
137
        }
NEW
138
    }
×
139

140
    // 1. X^2
141
    auto& pow_block = builder.add_block_before(parent, block, {}, this->debug_info());
2!
142
    auto& pow_in_node = builder.add_access(pow_block, in_node.data(), this->debug_info());
2!
143
    auto& pow_out_node = builder.add_access(pow_block, tmp_x2_name, this->debug_info());
2!
144

145
    auto& pow_node_1 = builder.add_library_node<PowNode>(pow_block, this->debug_info(), shape_);
2!
146
    auto& const_2 =
2✔
147
        builder.add_constant(pow_block, "2", types::Scalar(types::PrimitiveType::Int64), this->debug_info());
2!
148

149
    builder
4✔
150
        .add_computational_memlet(pow_block, pow_in_node, pow_node_1, "A", {}, in_edge.base_type(), this->debug_info());
2!
151
    builder.add_computational_memlet(
4!
152
        pow_block, const_2, pow_node_1, "B", {}, types::Scalar(types::PrimitiveType::Int64), this->debug_info()
2!
153
    );
154
    builder
4✔
155
        .add_computational_memlet(pow_block, pow_node_1, "C", pow_out_node, {}, in_edge.base_type(), this->debug_info());
2!
156

157
    // 2. Mean(X^2)
158
    auto& mean_x2_block = builder.add_block_before(parent, block, {}, this->debug_info());
2!
159
    auto& mean_x2_in_node = builder.add_access(mean_x2_block, tmp_x2_name, this->debug_info());
2!
160
    auto& mean_x2_out_node = builder.add_access(mean_x2_block, tmp_mean_x2_name, this->debug_info());
2!
161

162
    auto& mean_node_1 = builder.add_library_node<MeanNode>(mean_x2_block, this->debug_info(), shape_, axes_, keepdims_);
2!
163
    builder.add_computational_memlet(
4!
164
        mean_x2_block, mean_x2_in_node, mean_node_1, "X", {}, in_edge.base_type(), this->debug_info()
2!
165
    );
166
    builder.add_computational_memlet(
4!
167
        mean_x2_block, mean_node_1, "Y", mean_x2_out_node, {}, out_edge.base_type(), this->debug_info()
2!
168
    );
169

170
    // 3. Mean(X)
171
    auto& mean_x_block = builder.add_block_before(parent, block, {}, this->debug_info());
2!
172
    auto& mean_x_in_node = builder.add_access(mean_x_block, in_node.data(), this->debug_info());
2!
173
    auto& mean_x_out_node = builder.add_access(mean_x_block, tmp_mean_x_name, this->debug_info());
2!
174

175
    auto& mean_node_2 = builder.add_library_node<MeanNode>(mean_x_block, this->debug_info(), shape_, axes_, keepdims_);
2!
176
    builder.add_computational_memlet(
4!
177
        mean_x_block, mean_x_in_node, mean_node_2, "X", {}, in_edge.base_type(), this->debug_info()
2!
178
    );
179
    builder.add_computational_memlet(
4!
180
        mean_x_block, mean_node_2, "Y", mean_x_out_node, {}, out_edge.base_type(), this->debug_info()
2!
181
    );
182

183
    // 4. Mean(X)^2
184
    auto& pow_mean_x_block = builder.add_block_before(parent, block, {}, this->debug_info());
2!
185
    auto& pow_mean_x_in_node = builder.add_access(pow_mean_x_block, tmp_mean_x_name, this->debug_info());
2!
186
    auto& pow_mean_x_out_node = builder.add_access(pow_mean_x_block, tmp_mean_x_name, this->debug_info());
2!
187

188
    auto& pow_node_2 = builder.add_library_node<PowNode>(pow_mean_x_block, this->debug_info(), output_shape);
2!
189
    auto& const_2_2 =
2✔
190
        builder.add_constant(pow_mean_x_block, "2", types::Scalar(types::PrimitiveType::Int64), this->debug_info());
2!
191

192
    builder.add_computational_memlet(
4!
193
        pow_mean_x_block, pow_mean_x_in_node, pow_node_2, "A", {}, out_edge.base_type(), this->debug_info()
2!
194
    );
195
    builder.add_computational_memlet(
4!
196
        pow_mean_x_block, const_2_2, pow_node_2, "B", {}, types::Scalar(types::PrimitiveType::Int64), this->debug_info()
2!
197
    );
198
    builder.add_computational_memlet(
4!
199
        pow_mean_x_block, pow_node_2, "C", pow_mean_x_out_node, {}, out_edge.base_type(), this->debug_info()
2!
200
    );
201

202
    // 5. Mean(X^2) - Mean(X)^2
203
    auto& sub_block = builder.add_block_before(parent, block, {}, this->debug_info());
2!
204
    auto& sub_in1_node = builder.add_access(sub_block, tmp_mean_x2_name, this->debug_info());
2!
205
    auto& sub_in2_node = builder.add_access(sub_block, tmp_mean_x_name, this->debug_info());
2!
206
    auto& sub_out_node = builder.add_access(sub_block, out_node.data(), this->debug_info());
2!
207

208
    auto& sub_node = builder.add_library_node<SubNode>(sub_block, this->debug_info(), output_shape);
2!
209
    builder
4✔
210
        .add_computational_memlet(sub_block, sub_in1_node, sub_node, "A", {}, out_edge.base_type(), this->debug_info());
2!
211
    builder
4✔
212
        .add_computational_memlet(sub_block, sub_in2_node, sub_node, "B", {}, out_edge.base_type(), this->debug_info());
2!
213
    builder
4✔
214
        .add_computational_memlet(sub_block, sub_node, "C", sub_out_node, {}, out_edge.base_type(), this->debug_info());
2!
215

216
    // 6. Sqrt(...)
217
    auto& sqrt_block = builder.add_block_before(parent, block, transition.assignments(), this->debug_info());
2!
218
    auto& sqrt_in_node = builder.add_access(sqrt_block, out_node.data(), this->debug_info());
2!
219
    auto& sqrt_out_node = builder.add_access(sqrt_block, out_node.data(), this->debug_info());
2!
220

221
    auto& sqrt_node = builder.add_library_node<SqrtNode>(sqrt_block, this->debug_info(), output_shape);
2!
222
    builder
4✔
223
        .add_computational_memlet(sqrt_block, sqrt_in_node, sqrt_node, "X", {}, out_edge.base_type(), this->debug_info());
2!
224
    builder
4✔
225
        .add_computational_memlet(sqrt_block, sqrt_node, "Y", sqrt_out_node, {}, out_edge.base_type(), this->debug_info());
2!
226

227
    // Cleanup
228
    builder.remove_memlet(block, in_edge);
2!
229
    builder.remove_memlet(block, out_edge);
2!
230
    builder.remove_node(block, in_node);
2!
231
    builder.remove_node(block, out_node);
2!
232
    builder.remove_node(block, *this);
2!
233

234
    int last_index = parent.index(block);
2!
235
    builder.remove_child(parent, last_index);
2!
236

237
    return true;
2✔
238
}
2✔
239

NEW
240
bool StdNode::expand_reduction(
×
241
    builder::StructuredSDFGBuilder& builder,
242
    analysis::AnalysisManager& analysis_manager,
243
    structured_control_flow::Sequence& body,
244
    const std::string& input_name,
245
    const std::string& output_name,
246
    const types::IType& input_type,
247
    const types::IType& output_type,
248
    const data_flow::Subset& input_subset,
249
    const data_flow::Subset& output_subset
250
) {
NEW
251
    throw std::runtime_error("StdNode::expand_reduction should not be called");
×
NEW
252
}
×
253

NEW
254
/// Identity value of this reduction, returned as the literal string "0".
/// NOTE(review): presumably consumed only by the generic reduction path, which
/// expand_reduction() rejects for StdNode — confirm whether any caller relies on it.
std::string StdNode::identity() const {
    return std::string("0");
}
255

256
std::unique_ptr<data_flow::DataFlowNode> StdNode::
NEW
257
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
258
    return std::unique_ptr<
×
NEW
259
        data_flow::DataFlowNode>(new StdNode(element_id, debug_info_, vertex, parent, shape_, axes_, keepdims_));
×
NEW
260
}
×
261

262
} // namespace tensor
263
} // namespace math
264
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc