• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27981272983

22 Jun 2026 08:18PM UTC coverage: 61.754% (-0.03%) from 61.782%
27981272983

Pull #781

github

web-flow
Merge bddaa3724 into fe87d162b
Pull Request #781: Extend Segformer benchmarks setup

987 of 1432 new or added lines in 62 files covered. (68.92%)

9 existing lines in 7 files now uncovered.

38121 of 61730 relevant lines covered (61.75%)

993.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.88
/sdfg/src/data_flow/library_nodes/math/tensor/batchnorm_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/batchnorm_node.h"
2

3
#include "sdfg/builder/structured_sdfg_builder.h"
4
#include "sdfg/data_flow/access_node.h"
5
#include "sdfg/data_flow/library_nodes/math/cmath/cmath_node.h"
6
#include "sdfg/data_flow/library_nodes/math/tensor/tensor_expansion_utils.h"
7
#include "sdfg/structured_control_flow/block.h"
8
#include "sdfg/structured_control_flow/structured_loop.h"
9

10
namespace sdfg::math::tensor {
11

12

13
BatchNormNode::BatchNormNode(
14
    size_t element_id,
15
    const DebugInfo& debug_info,
16
    graph::Vertex vertex,
17
    data_flow::DataFlowGraph& parent,
18
    TensorLayout layout,
19
    QuantizationType quantization,
20
    data_flow::ImplementationType impl_type
21
)
22
    : TensorNode(
22✔
23
          element_id,
22✔
24
          debug_info,
22✔
25
          vertex,
22✔
26
          parent,
22✔
27
          LibraryNodeType_BatchNorm,
22✔
28
          {},
22✔
29
          {"Batch", "Var", "E", "Gamma", "Beta", "epsilon", "B_out"},
22✔
30
          std::move(impl_type)
22✔
31
      ),
22✔
32
      layout_(std::move(layout)), quantization_(quantization) {}
22✔
33

34
// Returns the free symbols referenced by this node. The tensor layout is the
// only member that carries symbolic expressions, so its symbols are the node's.
symbolic::SymbolSet BatchNormNode::symbols() const {
    symbolic::SymbolSet layout_symbols;
    this->layout_.collect_symbols(layout_symbols);
    return layout_symbols;
}
39

40
// Returns the primitive type this node is quantized to.
types::PrimitiveType BatchNormNode::quantization() const {
    return this->quantization_;
}
41

42
// Overrides the primitive type this node is quantized to.
void BatchNormNode::set_quantization(const types::PrimitiveType quant) {
    this->quantization_ = quant;
}
43

44
// Substitutes `new_expression` for `old_expression` throughout the tensor
// layout, which is the only symbol-carrying member of this node.
void BatchNormNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    this->layout_.replace_symbols(old_expression, new_expression);
}
47

NEW
48
// Applies a set of symbol substitutions to the tensor layout.
void BatchNormNode::replace(const symbolic::ExpressionMapping& replacements) {
    this->layout_.replace_symbols(replacements);
}
49

50
std::unique_ptr<data_flow::DataFlowNode> BatchNormNode::
51
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
52
    return std::unique_ptr<data_flow::DataFlowNode>(new BatchNormNode(
×
53
        element_id, debug_info(), vertex, parent, this->layout_, this->quantization_, this->implementation_type_
×
54
    ));
×
55
}
×
56

57
// Human-readable summary of the node: "BatchNorm(" + layout text + ")".
std::string BatchNormNode::toStr() const {
    std::string text = "BatchNorm(";
    text += layout_.toStr();
    text += ")";
    return text;
}
58

59
bool BatchNormNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
1✔
60
    // CPU implementation of batchnorm:
61
    if (false) {
1✔
62
        auto& dataflow = this->get_parent();
×
63
        auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
64

65
        auto& parent = static_cast<structured_control_flow::Sequence&>(*block.get_parent());
×
66
        int index = parent.index(block);
×
67
        auto& transition = parent.at(index).second;
×
68

69
        auto batch_in = find_usable_input_access_node(dataflow, *this, "Batch");
×
70
        auto& data_type = batch_in.memlet->base_type();
×
71
        types::Scalar scalar_type(data_type.primitive_type());
×
72
        types::Tensor tensor_1d(scalar_type, {num_features()}, {symbolic::one()}); // TODO verify / get from inputs
×
73
        std::string temp_var_prefix = "_batchn_tmp";
×
74
        int tmp_idx = 0;
×
75
        auto var_in = find_usable_input_access_node(dataflow, *this, "Var");
×
76
        auto e_in = find_usable_input_access_node(dataflow, *this, "E");
×
77
        auto gamma_in = find_usable_input_access_node(dataflow, *this, "Gamma");
×
78
        auto beta_in = find_usable_input_access_node(dataflow, *this, "Beta");
×
79
        auto result_ptr_in = find_usable_input_access_node(dataflow, *this, "B_out");
×
80
        auto eps_in = find_usable_input_access_node(dataflow, *this, "epsilon");
×
81

82
        auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), debug_info());
×
83

84
        auto loop_dims = create_maps(builder, layout_.shape(), new_sequence);
×
85

86
        auto& c_dim = loop_dims.at(1);
×
87
        std::vector<symbolic::Expression> c_subset{c_dim.indvar};
×
88
        auto interm_name = builder.find_new_name("_b_sqrt_div");
×
89
        builder.add_container(interm_name, scalar_type);
×
90
        auto& inter_block = builder.add_block_before(
×
91
            c_dim.seq, static_cast<structured_control_flow::ControlFlowNode&>(loop_dims.at(2).loop), {}, DebugInfo()
×
92
        );
×
93

94
        auto& var_elem_in = builder.add_access(inter_block, var_in.name);
×
95
        data_flow::AccessNode& epsilon_const = eps_in.is_const
×
96
                                                   ? builder.add_constant(inter_block, eps_in.name, scalar_type)
×
97
                                                   : builder.add_access(inter_block, eps_in.name);
×
98

99
        auto& add_eps_op = builder.add_tasklet(inter_block, data_flow::fp_add, "_out", {"var", "eps"}, debug_info());
×
100

101
        builder.add_computational_memlet(inter_block, var_elem_in, add_eps_op, "var", c_subset, tensor_1d);
×
102
        builder.add_computational_memlet(inter_block, epsilon_const, add_eps_op, "eps", {}, scalar_type);
×
103

104
        auto tmp_eps_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
105
        auto& tmp_eps = builder.add_access(inter_block, tmp_eps_name);
×
106

107
        builder.add_computational_memlet(inter_block, add_eps_op, "_out", tmp_eps, {}, scalar_type);
×
108

109
        auto tmp_sqrt_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
110
        auto& tmp_sqrt = builder.add_access(inter_block, tmp_sqrt_name);
×
111

112
        auto& sqrt_op = builder.add_library_node<
×
113
            cmath::CMathNode>(inter_block, debug_info(), cmath::CMathFunction::sqrt, data_type.primitive_type());
×
114

115
        builder.add_computational_memlet(inter_block, tmp_eps, sqrt_op, "_in1", {}, scalar_type);
×
116

117
        builder.add_computational_memlet(inter_block, sqrt_op, "_out", tmp_sqrt, {}, scalar_type);
×
118

119
        auto& one_const = builder.add_constant(inter_block, "1.0", scalar_type);
×
120
        auto& div_op = builder.add_tasklet(inter_block, data_flow::fp_div, "_out", {"one", "sqrt"});
×
121
        builder.add_computational_memlet(inter_block, one_const, div_op, "one", {}, scalar_type);
×
122
        builder.add_computational_memlet(inter_block, tmp_sqrt, div_op, "sqrt", {}, scalar_type);
×
123

124
        auto& interm_store = builder.add_access(inter_block, interm_name);
×
125
        builder.add_computational_memlet(inter_block, div_op, "_out", interm_store, {}, scalar_type);
×
126

127
        auto& innermost_dim = loop_dims.at(layout_.dims() - 1);
×
128

129
        std::vector<symbolic::Expression> innermost_subset;
×
130
        for (auto& builder_map_dim : loop_dims) {
×
131
            innermost_subset.push_back(builder_map_dim.indvar);
×
132
        }
×
133

134
        auto& innermost_block = builder.add_block(innermost_dim.seq);
×
135
        auto& x_in = builder.add_access(innermost_block, batch_in.name);
×
136
        auto& interm_in = builder.add_access(innermost_block, interm_name);
×
137
        auto& e_elem_in = builder.add_access(innermost_block, e_in.name);
×
138
        auto& gamma_elem_in = builder.add_access(innermost_block, gamma_in.name);
×
139
        auto& beta_elem_in = builder.add_access(innermost_block, beta_in.name);
×
140

141
        auto& result_ptr_out_elem = builder.add_access(innermost_block, result_ptr_in.name);
×
142

143
        auto& sub_op = builder.add_tasklet(innermost_block, data_flow::fp_sub, "_out", {"x", "e"}, debug_info());
×
144

145
        builder.add_computational_memlet(innermost_block, x_in, sub_op, "x", innermost_subset, data_type);
×
146
        builder.add_computational_memlet(innermost_block, e_elem_in, sub_op, "e", c_subset, tensor_1d);
×
147
        auto tmp_sub_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
148
        auto& tmp_sub = builder.add_access(innermost_block, tmp_sub_name);
×
149
        builder.add_computational_memlet(innermost_block, sub_op, "_out", tmp_sub, {}, scalar_type);
×
150

151
        auto& mul_interm_op =
×
152
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"num", "den"}, debug_info());
×
153

154
        builder.add_computational_memlet(innermost_block, tmp_sub, mul_interm_op, "num", {}, scalar_type);
×
155
        builder.add_computational_memlet(innermost_block, interm_in, mul_interm_op, "den", {}, scalar_type);
×
156
        auto tmp_interm = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
157
        auto& tmp_mul_interm = builder.add_access(innermost_block, tmp_interm);
×
158
        builder.add_computational_memlet(innermost_block, mul_interm_op, "_out", tmp_mul_interm, {}, scalar_type);
×
159

160
        auto& mul_gamma_op =
×
161
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"frac", "g"}, debug_info());
×
162

163
        builder.add_computational_memlet(innermost_block, tmp_mul_interm, mul_gamma_op, "frac", {}, scalar_type);
×
164
        builder.add_computational_memlet(innermost_block, gamma_elem_in, mul_gamma_op, "g", c_subset, tensor_1d);
×
165

166
        auto tmp_gamma = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
167
        auto& tmp_mul_gamma = builder.add_access(innermost_block, tmp_gamma);
×
168
        builder.add_computational_memlet(innermost_block, mul_gamma_op, "_out", tmp_mul_gamma, {}, scalar_type);
×
169

170
        auto& add_beta_op = builder.add_tasklet(innermost_block, data_flow::fp_add, "_out", {"_in", "b"}, debug_info());
×
171

172
        builder.add_computational_memlet(innermost_block, tmp_mul_gamma, add_beta_op, "_in", {}, scalar_type);
×
173
        builder.add_computational_memlet(innermost_block, beta_elem_in, add_beta_op, "b", c_subset, tensor_1d);
×
174
        builder.add_computational_memlet(
×
175
            innermost_block, add_beta_op, "_out", result_ptr_out_elem, innermost_subset, data_type
×
176
        );
×
177

178
        batch_in.remove_old(builder, block);
×
179
        var_in.remove_old(builder, block);
×
180
        e_in.remove_old(builder, block);
×
181
        eps_in.remove_old(builder, block);
×
182
        gamma_in.remove_old(builder, block);
×
183
        beta_in.remove_old(builder, block);
×
184
        result_ptr_in.remove_old(builder, block);
×
185

186
        builder.remove_node(block, *this);
×
187
        assert(dataflow.nodes().size() == 0 && "At expand time, no other nodes may be in the same graph");
×
188
        builder.remove_child(parent, index + 1);
×
189

190
        return true;
×
191
    } else {
1✔
192
        // GPU implementation of batchnorm:
193
        // Move sqrt and division into the innermost loop to enable more parallelism.
194
        auto& dataflow = this->get_parent();
1✔
195
        auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
1✔
196

197
        auto& parent = static_cast<structured_control_flow::Sequence&>(*block.get_parent());
1✔
198
        int index = parent.index(block);
1✔
199
        auto& transition = parent.at(index).second;
1✔
200

201
        auto batch_in = find_usable_input_access_node(dataflow, *this, "Batch");
1✔
202
        auto& data_type = batch_in.memlet->base_type();
1✔
203
        types::Scalar scalar_type(data_type.primitive_type());
1✔
204
        types::Tensor tensor_1d(scalar_type, {num_features()}, {symbolic::one()});
1✔
205
        std::string temp_var_prefix = "_batchn_tmp";
1✔
206
        int tmp_idx = 0;
1✔
207
        auto var_in = find_usable_input_access_node(dataflow, *this, "Var");
1✔
208
        auto e_in = find_usable_input_access_node(dataflow, *this, "E");
1✔
209
        auto gamma_in = find_usable_input_access_node(dataflow, *this, "Gamma");
1✔
210
        auto beta_in = find_usable_input_access_node(dataflow, *this, "Beta");
1✔
211
        auto result_ptr_in = find_usable_input_access_node(dataflow, *this, "B_out");
1✔
212
        auto eps_in = find_usable_input_access_node(dataflow, *this, "epsilon");
1✔
213

214
        auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), debug_info());
1✔
215

216
        auto loop_dims = create_maps(builder, layout_.shape(), new_sequence);
1✔
217

218
        auto& c_dim = loop_dims.at(1);
1✔
219
        std::vector<symbolic::Expression> c_subset{c_dim.indvar};
1✔
220

221
        auto& innermost_dim = loop_dims.at(layout_.dims() - 1);
1✔
222

223
        std::vector<symbolic::Expression> innermost_subset;
1✔
224
        for (auto& builder_map_dim : loop_dims) {
4✔
225
            innermost_subset.push_back(builder_map_dim.indvar);
4✔
226
        }
4✔
227

228
        auto& innermost_block = builder.add_block(innermost_dim.seq);
1✔
229

230
        // Access nodes
231
        auto& x_in = builder.add_access(innermost_block, batch_in.name);
1✔
232
        auto& var_elem_in = builder.add_access(innermost_block, var_in.name);
1✔
233
        data_flow::AccessNode& epsilon_const = eps_in.is_const
1✔
234
                                                   ? builder.add_constant(innermost_block, eps_in.name, scalar_type)
1✔
235
                                                   : builder.add_access(innermost_block, eps_in.name);
1✔
236
        auto& e_elem_in = builder.add_access(innermost_block, e_in.name);
1✔
237
        auto& gamma_elem_in = builder.add_access(innermost_block, gamma_in.name);
1✔
238
        auto& beta_elem_in = builder.add_access(innermost_block, beta_in.name);
1✔
239
        auto& result_ptr_out_elem = builder.add_access(innermost_block, result_ptr_in.name);
1✔
240

241
        // var[c] + eps
242
        auto& add_eps_op =
1✔
243
            builder.add_tasklet(innermost_block, data_flow::fp_add, "_out", {"var", "eps"}, debug_info());
1✔
244
        builder.add_computational_memlet(innermost_block, var_elem_in, add_eps_op, "var", c_subset, tensor_1d);
1✔
245
        builder.add_computational_memlet(innermost_block, epsilon_const, add_eps_op, "eps", {}, scalar_type);
1✔
246
        auto tmp_eps_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
247
        auto& tmp_eps = builder.add_access(innermost_block, tmp_eps_name);
1✔
248
        builder.add_computational_memlet(innermost_block, add_eps_op, "_out", tmp_eps, {}, scalar_type);
1✔
249

250
        // sqrt(var[c] + eps)
251
        auto tmp_sqrt_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
252
        auto& tmp_sqrt = builder.add_access(innermost_block, tmp_sqrt_name);
1✔
253
        auto& sqrt_op = builder.add_library_node<
1✔
254
            cmath::CMathNode>(innermost_block, debug_info(), cmath::CMathFunction::sqrt, data_type.primitive_type());
1✔
255
        builder.add_computational_memlet(innermost_block, tmp_eps, sqrt_op, "_in1", {}, scalar_type);
1✔
256
        builder.add_computational_memlet(innermost_block, sqrt_op, "_out", tmp_sqrt, {}, scalar_type);
1✔
257

258
        // 1.0 / sqrt(var[c] + eps)
259
        auto& one_const = builder.add_constant(innermost_block, "1.0", scalar_type);
1✔
260
        auto& div_op = builder.add_tasklet(innermost_block, data_flow::fp_div, "_out", {"one", "sqrt"});
1✔
261
        builder.add_computational_memlet(innermost_block, one_const, div_op, "one", {}, scalar_type);
1✔
262
        builder.add_computational_memlet(innermost_block, tmp_sqrt, div_op, "sqrt", {}, scalar_type);
1✔
263
        auto interm_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
264
        auto& interm_store = builder.add_access(innermost_block, interm_name);
1✔
265
        builder.add_computational_memlet(innermost_block, div_op, "_out", interm_store, {}, scalar_type);
1✔
266

267
        // x - e[c]
268
        auto& sub_op = builder.add_tasklet(innermost_block, data_flow::fp_sub, "_out", {"x", "e"}, debug_info());
1✔
269
        builder.add_computational_memlet(innermost_block, x_in, sub_op, "x", innermost_subset, data_type);
1✔
270
        builder.add_computational_memlet(innermost_block, e_elem_in, sub_op, "e", c_subset, tensor_1d);
1✔
271
        auto tmp_sub_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
272
        auto& tmp_sub = builder.add_access(innermost_block, tmp_sub_name);
1✔
273
        builder.add_computational_memlet(innermost_block, sub_op, "_out", tmp_sub, {}, scalar_type);
1✔
274

275
        // (x - e[c]) * (1/sqrt(var[c]+eps))
276
        auto& mul_interm_op =
1✔
277
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"num", "den"}, debug_info());
1✔
278
        builder.add_computational_memlet(innermost_block, tmp_sub, mul_interm_op, "num", {}, scalar_type);
1✔
279
        builder.add_computational_memlet(innermost_block, interm_store, mul_interm_op, "den", {}, scalar_type);
1✔
280
        auto tmp_interm = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
281
        auto& tmp_mul_interm = builder.add_access(innermost_block, tmp_interm);
1✔
282
        builder.add_computational_memlet(innermost_block, mul_interm_op, "_out", tmp_mul_interm, {}, scalar_type);
1✔
283

284
        // * gamma[c]
285
        auto& mul_gamma_op =
1✔
286
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"frac", "g"}, debug_info());
1✔
287
        builder.add_computational_memlet(innermost_block, tmp_mul_interm, mul_gamma_op, "frac", {}, scalar_type);
1✔
288
        builder.add_computational_memlet(innermost_block, gamma_elem_in, mul_gamma_op, "g", c_subset, tensor_1d);
1✔
289
        auto tmp_gamma = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
290
        auto& tmp_mul_gamma = builder.add_access(innermost_block, tmp_gamma);
1✔
291
        builder.add_computational_memlet(innermost_block, mul_gamma_op, "_out", tmp_mul_gamma, {}, scalar_type);
1✔
292

293
        // + beta[c]
294
        auto& add_beta_op = builder.add_tasklet(innermost_block, data_flow::fp_add, "_out", {"_in", "b"}, debug_info());
1✔
295
        builder.add_computational_memlet(innermost_block, tmp_mul_gamma, add_beta_op, "_in", {}, scalar_type);
1✔
296
        builder.add_computational_memlet(innermost_block, beta_elem_in, add_beta_op, "b", c_subset, tensor_1d);
1✔
297
        builder.add_computational_memlet(
1✔
298
            innermost_block, add_beta_op, "_out", result_ptr_out_elem, innermost_subset, data_type
1✔
299
        );
1✔
300

301
        batch_in.remove_old(builder, block);
1✔
302
        var_in.remove_old(builder, block);
1✔
303
        e_in.remove_old(builder, block);
1✔
304
        eps_in.remove_old(builder, block);
1✔
305
        gamma_in.remove_old(builder, block);
1✔
306
        beta_in.remove_old(builder, block);
1✔
307
        result_ptr_in.remove_old(builder, block);
1✔
308

309
        builder.remove_node(block, *this);
1✔
310
        assert(dataflow.nodes().size() == 0 && "At expand time, no other nodes may be in the same graph");
1✔
311
        builder.remove_child(parent, index + 1);
1✔
312

313
        return true;
1✔
314
    }
1✔
315
}
1✔
316

317
// Estimated floating-point operation count of one BatchNorm application.
//
// The shape is read as [N, C, d2, ..., d{k-1}]: dimension 1 is the channel
// axis, as in expand(). For every (sample, channel) pair the model charges:
//   * 3 flops for inv_std = 1 / sqrt(var + eps)  (add, sqrt, div)
//   * 4 flops per remaining element for (x - e) * inv_std * g + b
// Total: (4 * prod(d2..d{k-1}) + 3) * N * C.
//
// The inner element count is taken from the logical shape for every rank.
// The previous form used get_dim_innermost(0) * get_dim_innermost(1), which
// matches only rank-4 inputs.
//
// NOTE(review): expand() recomputes 1/sqrt(var + eps) for every element, so
// the generated code performs 7 flops per element. This model counts the
// hoisted, per-channel form.
symbolic::Expression BatchNormNode::flop() const {
    const auto& shape = layout_.shape();

    // Elements sharing one (sample, channel) pair: the product of all dimensions
    // after the channel axis (1 for a rank-2 [N, C] input).
    symbolic::Expression inner_elems = symbolic::one();
    for (size_t d = 2; d < shape.size(); ++d) {
        inner_elems = symbolic::mul(inner_elems, shape.at(d));
    }
    auto outer_elems = symbolic::mul(shape.at(0), shape.at(1));

    // (x-e) * sqrt_pre_calc * g + b = 4 flops
    auto inner_flops = symbolic::mul(symbolic::integer(4), inner_elems);
    // sqrt_pre_calc = 1/sqrt(var + eps) // 3 flops
    auto outer_flops = symbolic::mul(symbolic::add(inner_flops, symbolic::integer(3)), outer_elems);
    return outer_flops;
}
327

328
// Describes how each pointer input is accessed. Indices follow the input
// connector order registered in the constructor:
//   0 Batch, 1 Var, 2 E, 3 Gamma, 4 Beta, 5 epsilon, 6 B_out.
data_flow::PointerAccessType BatchNormNode::pointer_access_type(int input_idx) const {
    switch (input_idx) {
        // Batch, Var, E, Gamma and Beta are read only.
        case 0:
        case 1:
        case 2:
        case 3:
        case 4:
            return data_flow::PointerAccessMeta::create_read_only(symbolic::__nullptr__(), true);
        // B_out is written in full.
        case 6:
            return data_flow::PointerAccessMeta::create_full_write_only(symbolic::__nullptr__(), true);
        // epsilon (5) and any other index use the generic tensor-node behavior.
        default:
            return TensorNode::pointer_access_type(input_idx);
    }
}
337

338
// Serializes a BatchNormNode to JSON. The output holds the library node code
// under "code", the tensor layout under "batch_layout" and the quantization
// type under "batch_quant".
nlohmann::json BatchNormNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& batchnorm = static_cast<const BatchNormNode&>(library_node);

    nlohmann::json out;
    out["code"] = batchnorm.code().value();
    batchnorm.batch_layout().serialize_to_json(out["batch_layout"]);
    out["batch_quant"] = batchnorm.quantization();
    return out;
}
350

351
// Rebuilds a BatchNormNode from JSON and adds it to `parent`. The keys read
// are "batch_layout", "batch_quant" and "debug_info"; j.at() throws if any of
// them is missing.
data_flow::LibraryNode& BatchNormNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    auto batch_layout = TensorLayout::deserialize_from_json(j.at("batch_layout"));
    auto batch_quant = j.at("batch_quant").get<types::PrimitiveType>();

    serializer::JSONSerializer json_serializer;
    auto debug = json_serializer.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<BatchNormNode>(parent, debug, batch_layout, batch_quant);
}
362

363
} // namespace sdfg::math::tensor
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc