• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 17651658650

11 Sep 2025 04:58PM UTC coverage: 61.012% (+1.3%) from 59.755%
17651658650

Pull #219

github

web-flow
Merge 742a12367 into f744ac9f5
Pull Request #219: stdlib Library Nodes and ConstantNodes

499 of 1681 new or added lines in 81 files covered. (29.68%)

95 existing lines in 36 files now uncovered.

9718 of 15928 relevant lines covered (61.01%)

108.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/ml/batch_normalization.cpp
1
#include "sdfg/data_flow/library_nodes/math/ml/batch_normalization.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/analysis/scope_analysis.h"
5
#include "sdfg/builder/structured_sdfg_builder.h"
6

7
namespace sdfg {
8
namespace math {
9
namespace ml {
10

11
BatchNormalizationNode::BatchNormalizationNode(
×
12
    size_t element_id,
13
    const DebugInfoRegion &debug_info,
14
    const graph::Vertex vertex,
15
    data_flow::DataFlowGraph &parent,
16
    const std::vector<symbolic::Expression> &shape,
17
    int axis,
18
    const std::string &epsilon
19
)
20
    : MathNode(
×
21
          element_id,
×
22
          debug_info,
×
23
          vertex,
×
24
          parent,
×
25
          LibraryNodeType_BatchNormalization,
26
          {"Y"},
×
27
          {"X", "Scale", "B", "input_mean", "input_var"},
×
28
          data_flow::ImplementationType_NONE
29
      ),
NEW
30
      shape_(shape), axis_(axis), epsilon_(epsilon) {}
×
31

NEW
32
/**
 * Collects every symbol that occurs in any extent of the node's shape.
 *
 * @return the union of symbolic::atoms() over all shape dimensions
 */
symbolic::SymbolSet BatchNormalizationNode::symbols() const {
    symbolic::SymbolSet result;
    for (size_t i = 0; i < this->shape_.size(); ++i) {
        const auto dim_atoms = symbolic::atoms(this->shape_[i]);
        result.insert(dim_atoms.begin(), dim_atoms.end());
    }
    return result;
}
×
41

42
void BatchNormalizationNode::
NEW
43
    replace(const symbolic::Expression &old_expression, const symbolic::Expression &new_expression) {
×
NEW
44
    for (auto &dim : shape_) {
×
NEW
45
        dim = symbolic::subs(dim, old_expression, new_expression);
×
46
    }
NEW
47
}
×
48

NEW
49
// Performs no checks for this node; the function argument is unused.
void BatchNormalizationNode::validate(const Function & /* function */) const {
    // Nothing to validate.
}
×
50

51
bool BatchNormalizationNode::expand(builder::StructuredSDFGBuilder &builder, analysis::AnalysisManager &analysis_manager) {
×
52
    auto &dataflow = this->get_parent();
×
53
    auto &block = static_cast<structured_control_flow::Block &>(*dataflow.get_parent());
×
54

55
    auto &scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
56
    auto &parent = static_cast<structured_control_flow::Sequence &>(*scope_analysis.parent_scope(&block));
×
57
    int index = parent.index(block);
×
58
    auto &transition = parent.at(index).second;
×
59

60
    // Locate edges
61
    const data_flow::Memlet *iedge_input = nullptr;
×
62
    const data_flow::Memlet *iedge_scale = nullptr;
×
63
    const data_flow::Memlet *iedge_bias = nullptr;
×
64
    const data_flow::Memlet *iedge_mean = nullptr;
×
65
    const data_flow::Memlet *iedge_var = nullptr;
×
66
    const data_flow::Memlet *oedge_output = nullptr;
×
67
    for (const auto &edge : dataflow.in_edges(*this)) {
×
68
        if (edge.dst_conn() == "X") {
×
69
            iedge_input = &edge;
×
70
        } else if (edge.dst_conn() == "Scale") {
×
71
            iedge_scale = &edge;
×
72
        } else if (edge.dst_conn() == "B") {
×
73
            iedge_bias = &edge;
×
74
        } else if (edge.dst_conn() == "input_mean") {
×
75
            iedge_mean = &edge;
×
76
        } else if (edge.dst_conn() == "input_var") {
×
77
            iedge_var = &edge;
×
78
        }
×
79
    }
80
    for (const auto &edge : dataflow.out_edges(*this)) {
×
81
        if (edge.src_conn() == "Y") {
×
82
            oedge_output = &edge;
×
83
        }
×
84
    }
85
    if (!iedge_input || !iedge_scale || !iedge_bias || !iedge_mean || !iedge_var || !oedge_output) return false;
×
86

87
    std::string input_name = static_cast<const data_flow::AccessNode &>(iedge_input->src()).data();
×
88
    std::string scale_name = static_cast<const data_flow::AccessNode &>(iedge_scale->src()).data();
×
89
    std::string bias_name = static_cast<const data_flow::AccessNode &>(iedge_bias->src()).data();
×
90
    std::string mean_name = static_cast<const data_flow::AccessNode &>(iedge_mean->src()).data();
×
91
    std::string var_name = static_cast<const data_flow::AccessNode &>(iedge_var->src()).data();
×
92
    std::string output_name = static_cast<const data_flow::AccessNode &>(oedge_output->dst()).data();
×
93

94
    // Create new sequence before
95
    auto &new_sequence = builder.add_sequence_before(
×
96
        parent, block, transition.assignments(), builder.debug_info().get_region(block.debug_info().indices())
×
97
    );
98
    structured_control_flow::Sequence *last_scope = &new_sequence;
×
99

UNCOV
100
    std::vector<symbolic::Expression> loop_syms;
×
101
    structured_control_flow::Map *last_map = nullptr;
×
NEW
102
    for (size_t d = 0; d < this->shape_.size(); ++d) {
×
103
        std::string indvar_str = builder.find_new_name("_i");
×
104
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
×
105
        auto indvar = symbolic::symbol(indvar_str);
×
NEW
106
        auto init = symbolic::zero();
×
107
        auto update = symbolic::add(indvar, symbolic::one());
×
NEW
108
        auto cond = symbolic::Lt(indvar, this->shape_[d]);
×
109
        last_map = &builder.add_map(
×
110
            *last_scope,
×
111
            indvar,
112
            cond,
UNCOV
113
            init,
×
114
            update,
115
            structured_control_flow::ScheduleType_Sequential::create(),
×
116
            {},
×
117
            builder.subject().debug_info().get_region(block.debug_info().indices())
×
118
        );
119
        last_scope = &last_map->root();
×
120
        loop_syms.push_back(indvar);
×
121
    }
×
122

123
    // Create normalization block
124
    auto &norm_block = builder.add_block(*last_scope);
×
125

126
    // Create access nodes for normalization
127
    auto &input_access_norm = builder.add_access(norm_block, input_name);
×
128
    auto &scale_access_norm = builder.add_access(norm_block, scale_name);
×
129
    auto &bias_access_norm = builder.add_access(norm_block, bias_name);
×
130
    auto &mean_access_norm = builder.add_access(norm_block, mean_name);
×
131
    auto &var_access_norm = builder.add_access(norm_block, var_name);
×
132
    auto &output_access_norm = builder.add_access(norm_block, output_name);
×
133

134
    // Add epsilon to variance and compute standard deviation
135
    auto &add_epsilon_tasklet =
×
136
        builder.add_tasklet(norm_block, data_flow::TaskletCode::add, "_out", {"_in1", epsilon_});
×
137
    auto &var_eps_access = builder.add_access(norm_block, builder.find_new_name("_var_eps"));
×
138
    builder.add_computational_memlet(
×
139
        norm_block, var_access_norm, add_epsilon_tasklet, "_in1", loop_syms, iedge_var->base_type()
×
140
    );
141
    builder
×
142
        .add_computational_memlet(norm_block, add_epsilon_tasklet, "_out", var_eps_access, {}, iedge_var->base_type());
×
143

144
    auto &sqrt_tasklet = builder.add_tasklet(norm_block, data_flow::TaskletCode::sqrt, "_out", {"_in"});
×
145
    auto &std_dev_access = builder.add_access(norm_block, builder.find_new_name("_std_dev"));
×
146
    builder.add_computational_memlet(norm_block, var_eps_access, sqrt_tasklet, "_in", {}, iedge_var->base_type());
×
147
    builder.add_computational_memlet(norm_block, sqrt_tasklet, "_out", std_dev_access, {}, iedge_var->base_type());
×
148

149
    // Normalize: (x - mean) / std_dev
150
    auto &sub_norm_tasklet = builder.add_tasklet(norm_block, data_flow::TaskletCode::sub, "_out", {"_in1", "_in2"});
×
151
    auto &centered_access = builder.add_access(norm_block, builder.find_new_name("_centered"));
×
152
    builder.add_computational_memlet(
×
153
        norm_block, input_access_norm, sub_norm_tasklet, "_in1", loop_syms, iedge_input->base_type()
×
154
    );
155
    builder.add_computational_memlet(
×
156
        norm_block, mean_access_norm, sub_norm_tasklet, "_in2", loop_syms, iedge_mean->base_type()
×
157
    );
158
    builder
×
159
        .add_computational_memlet(norm_block, sub_norm_tasklet, "_out", centered_access, {}, iedge_input->base_type());
×
160

161
    auto &div_norm_tasklet = builder.add_tasklet(norm_block, data_flow::TaskletCode::div, "_out", {"_in1", "_in2"});
×
162
    auto &normalized_access = builder.add_access(norm_block, builder.find_new_name("_normalized"));
×
163
    builder
×
164
        .add_computational_memlet(norm_block, centered_access, div_norm_tasklet, "_in1", {}, iedge_input->base_type());
×
165
    builder
×
166
        .add_computational_memlet(norm_block, std_dev_access, div_norm_tasklet, "_in2", loop_syms, iedge_var->base_type());
×
167
    builder
×
168
        .add_computational_memlet(norm_block, div_norm_tasklet, "_out", normalized_access, {}, iedge_input->base_type());
×
169

170
    // Apply scale and bias: scale * normalized + bias
171
    auto &mul_scale_tasklet = builder.add_tasklet(norm_block, data_flow::TaskletCode::mul, "_out", {"_in1", "_in2"});
×
172
    auto &scaled_access = builder.add_access(norm_block, builder.find_new_name("_scaled"));
×
173
    builder
×
174
        .add_computational_memlet(norm_block, normalized_access, mul_scale_tasklet, "_in1", {}, iedge_input->base_type());
×
175
    builder.add_computational_memlet(
×
176
        norm_block, scale_access_norm, mul_scale_tasklet, "_in2", loop_syms, iedge_scale->base_type()
×
177
    );
178
    builder.add_computational_memlet(norm_block, mul_scale_tasklet, "_out", scaled_access, {}, iedge_input->base_type());
×
179

180
    auto &add_bias_tasklet = builder.add_tasklet(norm_block, data_flow::TaskletCode::add, "_out", {"_in1", "_in2"});
×
181
    builder.add_computational_memlet(norm_block, scaled_access, add_bias_tasklet, "_in1", {}, iedge_input->base_type());
×
182
    builder.add_computational_memlet(
×
183
        norm_block, bias_access_norm, add_bias_tasklet, "_in2", loop_syms, iedge_bias->base_type()
×
184
    );
185
    builder.add_computational_memlet(
×
186
        norm_block, add_bias_tasklet, "_out", output_access_norm, loop_syms, oedge_output->base_type()
×
187
    );
188

189
    // Cleanup old block
190
    builder.remove_memlet(block, *iedge_input);
×
191
    builder.remove_memlet(block, *iedge_scale);
×
192
    if (iedge_bias) {
×
193
        builder.remove_memlet(block, *iedge_bias);
×
194
    }
×
195
    builder.remove_memlet(block, *iedge_mean);
×
196
    builder.remove_memlet(block, *iedge_var);
×
197
    builder.remove_memlet(block, *oedge_output);
×
198
    builder.remove_node(block, *this);
×
199
    builder.remove_child(parent, index + 1);
×
200

201
    return true;
×
202
}
×
203

204
std::unique_ptr<data_flow::DataFlowNode> BatchNormalizationNode::
205
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph &parent) const {
×
206
    return std::unique_ptr<data_flow::DataFlowNode>(
×
NEW
207
        new BatchNormalizationNode(element_id, this->debug_info(), vertex, parent, this->shape_, axis_, epsilon_)
×
208
    );
209
}
×
210

211
/**
 * Serializes a BatchNormalizationNode to JSON.
 *
 * Emitted keys: "code", "axis", "epsilon" and "shape" (one serialized
 * expression per dimension, in order).
 */
nlohmann::json BatchNormalizationNodeSerializer::serialize(const data_flow::LibraryNode &library_node) {
    const auto &bn_node = static_cast<const BatchNormalizationNode &>(library_node);

    serializer::JSONSerializer expr_serializer;
    nlohmann::json shape_json = nlohmann::json::array();
    for (const auto &extent : bn_node.shape()) {
        shape_json.push_back(expr_serializer.expression(extent));
    }

    nlohmann::json j;
    j["code"] = bn_node.code().value();
    j["axis"] = bn_node.axis();
    j["epsilon"] = bn_node.epsilon();
    j["shape"] = shape_json;
    return j;
}
×
227

228
/**
 * Recreates a BatchNormalizationNode from JSON and adds it to `parent`.
 *
 * Required keys: "code", "debug_info", "shape", "axis", "epsilon".
 *
 * @throws std::runtime_error if "code" does not name a batch-normalization node
 * @throws nlohmann::json::out_of_range if a required key is missing
 */
data_flow::LibraryNode &BatchNormalizationNodeSerializer::deserialize(
    const nlohmann::json &j, builder::StructuredSDFGBuilder &builder, structured_control_flow::Block &parent
) {
    // `j` is const: operator[] with a missing key is undefined behavior for a const
    // nlohmann::json, so required keys are read with at(), which throws instead.
    const auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_BatchNormalization.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfoRegion debug_info = serializer.json_to_debug_info_region(j.at("debug_info"), builder.debug_info());

    const auto &shape_json = j.at("shape");
    std::vector<symbolic::Expression> shape;
    shape.reserve(shape_json.size());
    for (const auto &dim : shape_json) {
        shape.push_back(SymEngine::Expression(dim.get<std::string>()));
    }

    const auto axis = j.at("axis").get<int>();
    const auto epsilon = j.at("epsilon").get<std::string>();

    return builder.add_library_node<BatchNormalizationNode>(parent, debug_info, shape, axis, epsilon);
}
×
249

250
} // namespace ml
251
} // namespace math
252
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc