• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.02
/sdfg/src/data_flow/library_nodes/math/tensor/reduce_ops/std_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/std_node.h"
2
#include "sdfg/analysis/scope_analysis.h"
3
#include "sdfg/builder/structured_sdfg_builder.h"
4
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/mul_node.h"
5
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/sqrt_node.h"
6
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/sub_node.h"
7
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/mean_node.h"
8
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
9
#include "sdfg/structured_control_flow/block.h"
10
#include "sdfg/types/scalar.h"
11
#include "sdfg/types/utils.h"
12

13
namespace sdfg {
14
namespace math {
15
namespace tensor {
16

17
/// Builds a standard-deviation reduction node.
///
/// Every argument is handed straight to ReduceNode together with the
/// LibraryNodeType_Std tag; this constructor initializes nothing else.
StdNode::StdNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<int64_t>& axes,
    bool keepdims
)
    : ReduceNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_Std,
          shape,
          axes,
          keepdims
      ) {
    // Intentionally empty: the shape / axes / keepdims state is owned by ReduceNode.
}
2✔
27

28
bool StdNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
29
    auto& dataflow = this->get_parent();
2✔
30
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
2✔
31

32
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
2✔
33
        return false;
×
34
    }
×
35

36
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
2✔
37
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
2✔
38
    int index = parent.index(block);
2✔
39
    auto& transition = parent.at(index).second;
2✔
40

41
    auto& in_edge = *dataflow.in_edges(*this).begin();
2✔
42
    auto& out_edge = *dataflow.out_edges(*this).begin();
2✔
43
    auto& in_node = static_cast<data_flow::AccessNode&>(in_edge.src());
2✔
44
    auto& out_node = static_cast<data_flow::AccessNode&>(out_edge.dst());
2✔
45

46
    // Calculate output shape
47
    std::vector<symbolic::Expression> output_shape;
2✔
48
    std::vector<int64_t> sorted_axes = axes_;
2✔
49
    // Normalize negative axes
50
    for (auto& axis : sorted_axes) {
2✔
51
        if (axis < 0) {
2✔
NEW
52
            axis = static_cast<int64_t>(shape_.size()) + axis;
×
NEW
53
        }
×
54
        // Validate axis is in bounds
55
        if (axis < 0 || axis >= static_cast<int64_t>(shape_.size())) {
2✔
NEW
56
            throw InvalidSDFGException(
×
NEW
57
                "Library Node: Axis value out of bounds. Axis: " + std::to_string(axis) +
×
NEW
58
                " Shape size: " + std::to_string(shape_.size())
×
NEW
59
            );
×
NEW
60
        }
×
61
    }
2✔
62
    std::sort(sorted_axes.begin(), sorted_axes.end());
2✔
63

64
    for (size_t i = 0; i < shape_.size(); ++i) {
4✔
65
        bool is_axis = false;
2✔
66
        for (auto axis : sorted_axes) {
2✔
67
            if (axis == (int64_t) i) {
2✔
68
                is_axis = true;
2✔
69
                break;
2✔
70
            }
2✔
71
        }
2✔
72

73
        if (is_axis) {
2✔
74
            if (keepdims_) {
2✔
75
                output_shape.push_back(symbolic::one());
×
76
            }
×
77
        } else {
2✔
78
            output_shape.push_back(shape_[i]);
×
79
        }
×
80
    }
2✔
81

82
    types::Scalar element_type(this->primitive_type(dataflow));
2✔
83
    types::Pointer pointer_type(element_type);
2✔
84

85
    std::string tmp_x2_name = builder.find_new_name("_std_x2");
2✔
86
    builder.add_container(tmp_x2_name, pointer_type);
2✔
87
    std::string tmp_mean_x2_name = builder.find_new_name("_std_mean_x2");
2✔
88
    std::string tmp_mean_x_name = builder.find_new_name("_std_mean_x");
2✔
89

90
    symbolic::Expression bytes_in = types::get_type_size(element_type, false);
2✔
91
    for (auto& dim : this->shape_) {
2✔
92
        bytes_in = symbolic::mul(dim, bytes_in);
2✔
93
    }
2✔
94
    {
2✔
95
        auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
2✔
96
        auto& tmp_x2_name_access = builder.add_access(alloc_block, tmp_x2_name);
2✔
97
        auto& tmp_x2_name_malloc_node =
2✔
98
            builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), bytes_in);
2✔
99
        builder.add_computational_memlet(
2✔
100
            alloc_block, tmp_x2_name_malloc_node, "_ret", tmp_x2_name_access, {}, pointer_type, this->debug_info()
2✔
101
        );
2✔
102
    }
2✔
103

104
    if (!output_shape.empty()) {
2✔
NEW
105
        symbolic::Expression bytes_out = types::get_type_size(element_type, false);
×
106
        for (auto& dim : output_shape) {
×
107
            bytes_out = symbolic::mul(dim, bytes_out);
×
108
        }
×
NEW
109
        builder.add_container(tmp_mean_x2_name, pointer_type);
×
110
        {
×
111
            auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
112
            auto& tmp_mean_x2_name_access = builder.add_access(alloc_block, tmp_mean_x2_name);
×
113
            auto& tmp_mean_x2_name_malloc_node =
×
114
                builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), bytes_out);
×
115
            builder.add_computational_memlet(
×
116
                alloc_block,
×
117
                tmp_mean_x2_name_malloc_node,
×
118
                "_ret",
×
119
                tmp_mean_x2_name_access,
×
120
                {},
×
NEW
121
                pointer_type,
×
122
                this->debug_info()
×
123
            );
×
124
        }
×
125

NEW
126
        builder.add_container(tmp_mean_x_name, pointer_type);
×
127
        {
×
128
            auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
129
            auto& tmp_mean_x_name_access = builder.add_access(alloc_block, tmp_mean_x_name);
×
130
            auto& tmp_mean_x_name_malloc_node =
×
131
                builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), bytes_out);
×
132
            builder.add_computational_memlet(
×
133
                alloc_block,
×
134
                tmp_mean_x_name_malloc_node,
×
135
                "_ret",
×
136
                tmp_mean_x_name_access,
×
137
                {},
×
NEW
138
                pointer_type,
×
139
                this->debug_info()
×
140
            );
×
141
        }
×
142
    } else {
2✔
143
        builder.add_container(tmp_mean_x2_name, element_type);
2✔
144
        builder.add_container(tmp_mean_x_name, element_type);
2✔
145
    }
2✔
146

147
    // 1. X^2
148
    auto& pow_block = builder.add_block_before(parent, block, {}, this->debug_info());
2✔
149
    auto& pow_in_node = builder.add_access(pow_block, in_node.data(), this->debug_info());
2✔
150
    auto& pow_out_node = builder.add_access(pow_block, tmp_x2_name, this->debug_info());
2✔
151

152
    auto& pow_node_1 = builder.add_library_node<MulNode>(pow_block, this->debug_info(), shape_);
2✔
153
    builder
2✔
154
        .add_computational_memlet(pow_block, pow_in_node, pow_node_1, "A", {}, in_edge.base_type(), this->debug_info());
2✔
155
    builder
2✔
156
        .add_computational_memlet(pow_block, pow_in_node, pow_node_1, "B", {}, in_edge.base_type(), this->debug_info());
2✔
157
    builder
2✔
158
        .add_computational_memlet(pow_block, pow_node_1, "C", pow_out_node, {}, in_edge.base_type(), this->debug_info());
2✔
159

160
    // 2. Mean(X^2)
161
    auto& mean_x2_block = builder.add_block_before(parent, block, {}, this->debug_info());
2✔
162
    auto& mean_x2_in_node = builder.add_access(mean_x2_block, tmp_x2_name, this->debug_info());
2✔
163
    auto& mean_x2_out_node = builder.add_access(mean_x2_block, tmp_mean_x2_name, this->debug_info());
2✔
164

165
    auto& mean_node_1 = builder.add_library_node<MeanNode>(mean_x2_block, this->debug_info(), shape_, axes_, keepdims_);
2✔
166
    builder.add_computational_memlet(
2✔
167
        mean_x2_block, mean_x2_in_node, mean_node_1, "X", {}, in_edge.base_type(), this->debug_info()
2✔
168
    );
2✔
169
    builder.add_computational_memlet(
2✔
170
        mean_x2_block, mean_node_1, "Y", mean_x2_out_node, {}, out_edge.base_type(), this->debug_info()
2✔
171
    );
2✔
172

173
    // 3. Mean(X)
174
    auto& mean_x_block = builder.add_block_before(parent, block, {}, this->debug_info());
2✔
175
    auto& mean_x_in_node = builder.add_access(mean_x_block, in_node.data(), this->debug_info());
2✔
176
    auto& mean_x_out_node = builder.add_access(mean_x_block, tmp_mean_x_name, this->debug_info());
2✔
177

178
    auto& mean_node_2 = builder.add_library_node<MeanNode>(mean_x_block, this->debug_info(), shape_, axes_, keepdims_);
2✔
179
    builder.add_computational_memlet(
2✔
180
        mean_x_block, mean_x_in_node, mean_node_2, "X", {}, in_edge.base_type(), this->debug_info()
2✔
181
    );
2✔
182
    builder.add_computational_memlet(
2✔
183
        mean_x_block, mean_node_2, "Y", mean_x_out_node, {}, out_edge.base_type(), this->debug_info()
2✔
184
    );
2✔
185

186
    // 4. Mean(X)^2
187
    auto& pow_mean_x_block = builder.add_block_before(parent, block, {}, this->debug_info());
2✔
188
    auto& pow_mean_x_in_node = builder.add_access(pow_mean_x_block, tmp_mean_x_name, this->debug_info());
2✔
189
    auto& pow_mean_x_out_node = builder.add_access(pow_mean_x_block, tmp_mean_x_name, this->debug_info());
2✔
190

191
    auto& pow_node_2 = builder.add_library_node<MulNode>(pow_mean_x_block, this->debug_info(), output_shape);
2✔
192

193
    builder.add_computational_memlet(
2✔
194
        pow_mean_x_block, pow_mean_x_in_node, pow_node_2, "A", {}, out_edge.base_type(), this->debug_info()
2✔
195
    );
2✔
196
    builder.add_computational_memlet(
2✔
197
        pow_mean_x_block, pow_mean_x_in_node, pow_node_2, "B", {}, out_edge.base_type(), this->debug_info()
2✔
198
    );
2✔
199
    builder.add_computational_memlet(
2✔
200
        pow_mean_x_block, pow_node_2, "C", pow_mean_x_out_node, {}, out_edge.base_type(), this->debug_info()
2✔
201
    );
2✔
202

203
    // 5. Mean(X^2) - Mean(X)^2
204
    auto& sub_block = builder.add_block_before(parent, block, {}, this->debug_info());
2✔
205
    auto& sub_in1_node = builder.add_access(sub_block, tmp_mean_x2_name, this->debug_info());
2✔
206
    auto& sub_in2_node = builder.add_access(sub_block, tmp_mean_x_name, this->debug_info());
2✔
207
    auto& sub_out_node = builder.add_access(sub_block, out_node.data(), this->debug_info());
2✔
208

209
    auto& sub_node = builder.add_library_node<SubNode>(sub_block, this->debug_info(), output_shape);
2✔
210
    builder
2✔
211
        .add_computational_memlet(sub_block, sub_in1_node, sub_node, "A", {}, out_edge.base_type(), this->debug_info());
2✔
212
    builder
2✔
213
        .add_computational_memlet(sub_block, sub_in2_node, sub_node, "B", {}, out_edge.base_type(), this->debug_info());
2✔
214
    builder
2✔
215
        .add_computational_memlet(sub_block, sub_node, "C", sub_out_node, {}, out_edge.base_type(), this->debug_info());
2✔
216

217
    // 6. Sqrt(...)
218
    auto& sqrt_block = builder.add_block_before(parent, block, transition.assignments(), this->debug_info());
2✔
219
    auto& sqrt_in_node = builder.add_access(sqrt_block, out_node.data(), this->debug_info());
2✔
220
    auto& sqrt_out_node = builder.add_access(sqrt_block, out_node.data(), this->debug_info());
2✔
221

222
    auto& sqrt_node = builder.add_library_node<SqrtNode>(sqrt_block, this->debug_info(), output_shape);
2✔
223
    builder
2✔
224
        .add_computational_memlet(sqrt_block, sqrt_in_node, sqrt_node, "X", {}, out_edge.base_type(), this->debug_info());
2✔
225
    builder
2✔
226
        .add_computational_memlet(sqrt_block, sqrt_node, "Y", sqrt_out_node, {}, out_edge.base_type(), this->debug_info());
2✔
227

228
    // Cleanup
229
    builder.remove_memlet(block, in_edge);
2✔
230
    builder.remove_memlet(block, out_edge);
2✔
231
    builder.remove_node(block, in_node);
2✔
232
    builder.remove_node(block, out_node);
2✔
233
    builder.remove_node(block, *this);
2✔
234

235
    int last_index = parent.index(block);
2✔
236
    builder.remove_child(parent, last_index);
2✔
237

238
    return true;
2✔
239
}
2✔
240

241
bool StdNode::expand_reduction(
242
    builder::StructuredSDFGBuilder& builder,
243
    analysis::AnalysisManager& analysis_manager,
244
    structured_control_flow::Sequence& body,
245
    const std::string& input_name,
246
    const std::string& output_name,
247
    const types::Tensor& input_type,
248
    const types::Tensor& output_type,
249
    const data_flow::Subset& input_subset,
250
    const data_flow::Subset& output_subset
251
) {
×
252
    throw std::runtime_error("StdNode::expand_reduction should not be called");
×
253
}
×
254

255
// Identity value reported for this reduction, returned as the literal string "0".
std::string StdNode::identity() const {
    const std::string zero("0");
    return zero;
}
×
256

257
std::unique_ptr<data_flow::DataFlowNode> StdNode::
258
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
259
    return std::unique_ptr<
×
260
        data_flow::DataFlowNode>(new StdNode(element_id, debug_info_, vertex, parent, shape_, axes_, keepdims_));
×
261
}
×
262

263
} // namespace tensor
264
} // namespace math
265
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc