• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 24128050488

08 Apr 2026 09:23AM UTC coverage: 64.832% (-0.02%) from 64.848%
24128050488

push

github

web-flow
Batchnorm Node (#655)

* Batchnorm Node:

 + symbolic flop estimation
 ~ fixed: expand was missing the division operation.
 ~ ensure clone also respects impl type

* Added batchnorm2d test with explicit input values, to tease out the last bug.

 + toStr() for tensor layout gives layout details

10 of 29 new or added lines in 2 files covered. (34.48%)

1 existing line in 1 file now uncovered.

28998 of 44728 relevant lines covered (64.83%)

603.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.31
/sdfg/src/data_flow/library_nodes/math/tensor/batchnorm_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/batchnorm_node.h"
2

3
#include "sdfg/analysis/scope_analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5
#include "sdfg/data_flow/access_node.h"
6
#include "sdfg/data_flow/library_nodes/math/cmath/cmath_node.h"
7
#include "sdfg/structured_control_flow/block.h"
8
#include "sdfg/structured_control_flow/structured_loop.h"
9

10
namespace sdfg::math::tensor {
11

12

13
BatchNormNode::BatchNormNode(
14
    size_t element_id,
15
    const DebugInfo& debug_info,
16
    graph::Vertex vertex,
17
    data_flow::DataFlowGraph& parent,
18
    TensorLayout layout,
19
    types::PrimitiveType quantization,
20
    data_flow::ImplementationType impl_type
21
)
22
    : TensorNode(
1✔
23
          element_id,
1✔
24
          debug_info,
1✔
25
          vertex,
1✔
26
          parent,
1✔
27
          LibraryNodeType_BatchNorm,
1✔
28
          {},
1✔
29
          {"Batch", "Var", "E", "Gamma", "Beta", "epsilon", "B_out"},
1✔
30
          std::move(impl_type)
1✔
31
      ),
1✔
32
      layout_(std::move(layout)), quantization_(quantization) {}
1✔
33

34
/// Returns the set of symbols that the tensor layout reports via collect_symbols().
symbolic::SymbolSet BatchNormNode::symbols() const {
    symbolic::SymbolSet collected;
    layout_.collect_symbols(collected);
    return collected;
}
×
39

40
/// Returns the primitive type currently stored as this node's quantization.
types::PrimitiveType BatchNormNode::quantization() const {
    return this->quantization_;
}
×
41

42
/// Overwrites the stored quantization type with `quant`.
void BatchNormNode::set_quantization(const types::PrimitiveType quant) {
    this->quantization_ = quant;
}
×
43

44
/// Substitutes `old_expression` by `new_expression` inside the tensor layout.
/// The layout is the only state of this node that is rewritten here.
void BatchNormNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    this->layout_.replace_symbols(old_expression, new_expression);
}
×
47

48
std::unique_ptr<data_flow::DataFlowNode> BatchNormNode::
49
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
50
    return std::unique_ptr<data_flow::DataFlowNode>(new BatchNormNode(
×
NEW
51
        element_id, debug_info(), vertex, parent, this->layout_, this->quantization_, this->implementation_type_
×
NEW
52
    ));
×
UNCOV
53
}
×
54

55
/// Human-readable description: the layout's own string form wrapped in "BatchNorm(...)".
std::string BatchNormNode::toStr() const {
    std::string text = "BatchNorm(";
    text += layout_.toStr();
    text += ")";
    return text;
}
×
56

57
struct InputContainerInfo {
58
    std::string name;
59
    bool is_const = false;
60
    const data_flow::Memlet* memlet;
61
    const data_flow::AccessNode* access_to_remove = nullptr;
62

63
    void remove_old(builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& block) const {
7✔
64
        if (memlet) {
7✔
65
            builder.remove_memlet(block, *memlet);
7✔
66
        }
7✔
67
        if (access_to_remove) {
7✔
68
            builder.remove_node(block, *access_to_remove);
7✔
69
        }
7✔
70
    }
7✔
71
};
72

73
/// Looks up the edge entering `node` on connector `input_conn` and describes its source.
///
/// @throws InvalidSDFGException if no edge is attached to the connector, or if
///         the edge's source is not an AccessNode.
/// @return The source container name, whether the source is a ConstantNode, and
///         the edge / access node recorded for later removal.
InputContainerInfo find_usable_input_access_node(
    data_flow::DataFlowGraph& dataflow, data_flow::LibraryNode& node, const std::string& input_conn
) {
    auto* incoming = dataflow.in_edge_for_connector(node, input_conn);
    if (incoming == nullptr) {
        throw InvalidSDFGException(node.toStr() + " requires input on " + input_conn);
    }

    auto* source_access = dynamic_cast<const data_flow::AccessNode*>(&incoming->src());
    if (source_access == nullptr) {
        throw InvalidSDFGException(node.toStr() + " requires input on " + input_conn + " to be an access node");
    }

    // A second cast tells constants apart from ordinary access nodes.
    const bool source_is_constant = dynamic_cast<const data_flow::ConstantNode*>(&incoming->src()) != nullptr;

    InputContainerInfo info;
    info.name = source_access->data();
    info.is_const = source_is_constant;
    info.memlet = incoming;
    info.access_to_remove = source_access;
    return info;
}
7✔
92

93
struct BuilderMapDim {
94
    symbolic::Expression indvar;
95
    structured_control_flow::StructuredLoop& loop;
96
    structured_control_flow::Sequence& seq;
97
};
98

99
std::vector<BuilderMapDim> create_maps(
100
    builder::StructuredSDFGBuilder& builder,
101
    const std::vector<symbolic::Expression>& sizes,
102
    structured_control_flow::Sequence& block
103
) {
1✔
104
    std::vector<BuilderMapDim> scopes;
1✔
105

106
    Sequence* last_scope = &block;
1✔
107

108
    for (size_t i = 0; i < sizes.size(); i++) {
5✔
109
        std::string indvar_str = builder.find_new_name("_i");
4✔
110
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
4✔
111

112
        auto indvar = symbolic::symbol(indvar_str);
4✔
113
        auto init = symbolic::zero();
4✔
114
        auto update = symbolic::add(indvar, symbolic::one());
4✔
115
        auto condition = symbolic::Lt(indvar, sizes.at(i));
4✔
116
        auto& last_map = builder.add_map(
4✔
117
            *last_scope,
4✔
118
            indvar,
4✔
119
            condition,
4✔
120
            init,
4✔
121
            update,
4✔
122
            structured_control_flow::ScheduleType_Sequential::create(),
4✔
123
            {},
4✔
124
            block.debug_info()
4✔
125
        );
4✔
126
        auto& seq = last_map.root();
4✔
127
        last_scope = &seq;
4✔
128

129
        scopes.push_back(
4✔
130
            {.indvar = indvar, .loop = dynamic_cast<structured_control_flow::StructuredLoop&>(last_map), .seq = seq}
4✔
131
        );
4✔
132
    }
4✔
133

134
    return scopes;
1✔
135
}
1✔
136

137
std::string
138
create_temp_var(builder::StructuredSDFGBuilder& builder, const std::string& prefix, int gen, const types::IType& type) {
5✔
139
    std::string n = prefix + "_" + std::to_string(gen);
5✔
140
    auto name = builder.find_new_name(n);
5✔
141
    builder.add_container(name, type);
5✔
142
    return name;
5✔
143
}
5✔
144

145
bool BatchNormNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
1✔
146
    auto& dataflow = this->get_parent();
1✔
147
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
1✔
148

149
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
1✔
150
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
1✔
151
    int index = parent.index(block);
1✔
152
    auto& transition = parent.at(index).second;
1✔
153

154
    auto batch_in = find_usable_input_access_node(dataflow, *this, "Batch");
1✔
155
    auto& data_type = batch_in.memlet->base_type();
1✔
156
    types::Scalar scalar_type(data_type.primitive_type());
1✔
157
    types::Tensor tensor_1d(scalar_type, {num_features()}, {symbolic::one()}); // TODO verify / get from inputs
1✔
158
    std::string temp_var_prefix = "_batchn_tmp";
1✔
159
    int tmp_idx = 0;
1✔
160
    auto var_in = find_usable_input_access_node(dataflow, *this, "Var");
1✔
161
    auto e_in = find_usable_input_access_node(dataflow, *this, "E");
1✔
162
    auto gamma_in = find_usable_input_access_node(dataflow, *this, "Gamma");
1✔
163
    auto beta_in = find_usable_input_access_node(dataflow, *this, "Beta");
1✔
164
    auto result_ptr_in = find_usable_input_access_node(dataflow, *this, "B_out");
1✔
165
    auto eps_in = find_usable_input_access_node(dataflow, *this, "epsilon");
1✔
166

167
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), debug_info());
1✔
168

169
    auto loop_dims = create_maps(builder, layout_.shape(), new_sequence);
1✔
170

171
    auto& c_dim = loop_dims.at(1);
1✔
172
    std::vector<symbolic::Expression> c_subset{c_dim.indvar};
1✔
173
    auto interm_name = builder.find_new_name("_b_sqrt_div");
1✔
174
    builder.add_container(interm_name, scalar_type);
1✔
175
    auto& inter_block = builder.add_block_before(
1✔
176
        c_dim.seq, static_cast<structured_control_flow::ControlFlowNode&>(loop_dims.at(2).loop), {}, DebugInfo()
1✔
177
    );
1✔
178

179
    auto& var_elem_in = builder.add_access(inter_block, var_in.name);
1✔
180
    data_flow::AccessNode& epsilon_const = eps_in.is_const ? builder.add_constant(inter_block, eps_in.name, scalar_type)
1✔
181
                                                           : builder.add_access(inter_block, eps_in.name);
1✔
182

183
    auto& add_eps_op = builder.add_tasklet(inter_block, data_flow::fp_add, "_out", {"var", "eps"}, debug_info());
1✔
184

185
    builder.add_computational_memlet(inter_block, var_elem_in, add_eps_op, "var", c_subset, tensor_1d);
1✔
186
    builder.add_computational_memlet(inter_block, epsilon_const, add_eps_op, "eps", {}, scalar_type);
1✔
187

188
    auto tmp_eps_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
189
    auto& tmp_eps = builder.add_access(inter_block, tmp_eps_name);
1✔
190

191
    builder.add_computational_memlet(inter_block, add_eps_op, "_out", tmp_eps, {}, scalar_type);
1✔
192

193
    auto tmp_sqrt_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
194
    auto& tmp_sqrt = builder.add_access(inter_block, tmp_sqrt_name);
1✔
195

196
    auto& sqrt_op = builder.add_library_node<
1✔
197
        cmath::CMathNode>(inter_block, debug_info(), cmath::CMathFunction::sqrt, data_type.primitive_type());
1✔
198

199
    builder.add_computational_memlet(inter_block, tmp_eps, sqrt_op, "_in1", {}, scalar_type);
1✔
200

201
    builder.add_computational_memlet(inter_block, sqrt_op, "_out", tmp_sqrt, {}, scalar_type);
1✔
202

203
    auto& one_const = builder.add_constant(inter_block, "1.0", scalar_type);
1✔
204
    auto& div_op = builder.add_tasklet(inter_block, data_flow::fp_div, "_out", {"one", "sqrt"});
1✔
205
    builder.add_computational_memlet(inter_block, one_const, div_op, "one", {}, scalar_type);
1✔
206
    builder.add_computational_memlet(inter_block, tmp_sqrt, div_op, "sqrt", {}, scalar_type);
1✔
207

208
    auto& interm_store = builder.add_access(inter_block, interm_name);
1✔
209
    builder.add_computational_memlet(inter_block, div_op, "_out", interm_store, {}, scalar_type);
1✔
210

211
    auto& innermost_dim = loop_dims.at(layout_.dims() - 1);
1✔
212

213
    std::vector<symbolic::Expression> innermost_subset;
1✔
214
    for (auto& builder_map_dim : loop_dims) {
4✔
215
        innermost_subset.push_back(builder_map_dim.indvar);
4✔
216
    }
4✔
217

218
    auto& innermost_block = builder.add_block(innermost_dim.seq);
1✔
219
    auto& x_in = builder.add_access(innermost_block, batch_in.name);
1✔
220
    auto& interm_in = builder.add_access(innermost_block, interm_name);
1✔
221
    auto& e_elem_in = builder.add_access(innermost_block, e_in.name);
1✔
222
    auto& gamma_elem_in = builder.add_access(innermost_block, gamma_in.name);
1✔
223
    auto& beta_elem_in = builder.add_access(innermost_block, beta_in.name);
1✔
224

225
    auto& result_ptr_out_elem = builder.add_access(innermost_block, result_ptr_in.name);
1✔
226

227
    auto& sub_op = builder.add_tasklet(innermost_block, data_flow::fp_sub, "_out", {"x", "e"}, debug_info());
1✔
228

229
    builder.add_computational_memlet(innermost_block, x_in, sub_op, "x", innermost_subset, data_type);
1✔
230
    builder.add_computational_memlet(innermost_block, e_elem_in, sub_op, "e", c_subset, tensor_1d);
1✔
231
    auto tmp_sub_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
232
    auto& tmp_sub = builder.add_access(innermost_block, tmp_sub_name);
1✔
233
    builder.add_computational_memlet(innermost_block, sub_op, "_out", tmp_sub, {}, scalar_type);
1✔
234

235
    auto& mul_interm_op = builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"num", "den"}, debug_info());
1✔
236

237
    builder.add_computational_memlet(innermost_block, tmp_sub, mul_interm_op, "num", {}, scalar_type);
1✔
238
    builder.add_computational_memlet(innermost_block, interm_in, mul_interm_op, "den", {}, scalar_type);
1✔
239
    auto tmp_interm = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
240
    auto& tmp_mul_interm = builder.add_access(innermost_block, tmp_interm);
1✔
241
    builder.add_computational_memlet(innermost_block, mul_interm_op, "_out", tmp_mul_interm, {}, scalar_type);
1✔
242

243
    auto& mul_gamma_op = builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"frac", "g"}, debug_info());
1✔
244

245
    builder.add_computational_memlet(innermost_block, tmp_mul_interm, mul_gamma_op, "frac", {}, scalar_type);
1✔
246
    builder.add_computational_memlet(innermost_block, gamma_elem_in, mul_gamma_op, "g", c_subset, tensor_1d);
1✔
247

248
    auto tmp_gamma = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
249
    auto& tmp_mul_gamma = builder.add_access(innermost_block, tmp_gamma);
1✔
250
    builder.add_computational_memlet(innermost_block, mul_gamma_op, "_out", tmp_mul_gamma, {}, scalar_type);
1✔
251

252
    auto& add_beta_op = builder.add_tasklet(innermost_block, data_flow::fp_add, "_out", {"_in", "b"}, debug_info());
1✔
253

254
    builder.add_computational_memlet(innermost_block, tmp_mul_gamma, add_beta_op, "_in", {}, scalar_type);
1✔
255
    builder.add_computational_memlet(innermost_block, beta_elem_in, add_beta_op, "b", c_subset, tensor_1d);
1✔
256
    builder
1✔
257
        .add_computational_memlet(innermost_block, add_beta_op, "_out", result_ptr_out_elem, innermost_subset, data_type);
1✔
258

259
    batch_in.remove_old(builder, block);
1✔
260
    var_in.remove_old(builder, block);
1✔
261
    e_in.remove_old(builder, block);
1✔
262
    eps_in.remove_old(builder, block);
1✔
263
    gamma_in.remove_old(builder, block);
1✔
264
    beta_in.remove_old(builder, block);
1✔
265
    result_ptr_in.remove_old(builder, block);
1✔
266

267
    builder.remove_node(block, *this);
1✔
268
    assert(dataflow.nodes().size() == 0 && "At expand time, no other nodes may be in the same graph");
1✔
269
    builder.remove_child(parent, index + 1);
1✔
270

271
    return true;
1✔
272
}
1✔
273

NEW
274
/// Symbolic estimate of the floating-point operations of this node:
///
///     (4 * innermost(0) * innermost(1) + 3) * shape[0] * shape[1]
///
/// NOTE(review): the formula combines exactly two "innermost" dimensions with
/// the first two shape entries, which presumably targets 4-D layouts — confirm
/// for other ranks.
symbolic::Expression BatchNormNode::flop() const {
    const auto per_plane_elems = symbolic::mul(layout_.get_dim_innermost(0), layout_.get_dim_innermost(1));
    const auto plane_count = symbolic::mul(layout_.shape().at(0), layout_.shape().at(1));

    // (x-e) * sqrt_pre_calc * g + b = 4 flops per element
    const auto per_plane_flops = symbolic::mul(symbolic::integer(4), per_plane_elems);

    // sqrt_pre_calc = 1/sqrt(var + eps) adds 3 flops once per plane
    return symbolic::mul(symbolic::add(per_plane_flops, symbolic::integer(3)), plane_count);
}
×
284

285
/// Serialises a BatchNormNode to JSON.
///
/// Written keys: "code" (the node's code value), "batch_layout" (filled in by
/// the layout itself) and "batch_quant" (the quantization type).
/// `library_node` is cast to BatchNormNode without a runtime check.
nlohmann::json BatchNormNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& batchnorm = static_cast<const BatchNormNode&>(library_node);

    nlohmann::json out;
    out["code"] = batchnorm.code().value();
    batchnorm.batch_layout().serialize_to_json(out["batch_layout"]);
    out["batch_quant"] = batchnorm.quantization();
    return out;
}
×
297

298
/// Recreates a BatchNormNode inside `parent` from its JSON form.
///
/// Reads "batch_layout", "batch_quant" and "debug_info" (in that order); each is
/// fetched with `at()`, so a missing key raises the JSON library's exception.
/// NOTE(review): serialize() in this file does not write "debug_info" —
/// presumably the surrounding serializer adds it; confirm.
data_flow::LibraryNode& BatchNormNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    auto batch_layout = TensorLayout::deserialize_from_json(j.at("batch_layout"));
    auto batch_quant = j.at("batch_quant").get<types::PrimitiveType>();

    serializer::JSONSerializer json_reader;
    auto dbg = json_reader.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<BatchNormNode>(parent, dbg, batch_layout, batch_quant);
}
×
309

310
} // namespace sdfg::math::tensor
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc