• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 24391284733

14 Apr 2026 09:22AM UTC coverage: 64.338%. First build
24391284733

Pull #677

github

web-flow
Merge b26310f79 into 4647dcb0f
Pull Request #677: Fully Elementwise Batchnorm for GPUs

97 of 195 new or added lines in 1 file covered. (49.74%)

30556 of 47493 relevant lines covered (64.34%)

582.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

55.59
/sdfg/src/data_flow/library_nodes/math/tensor/batchnorm_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/batchnorm_node.h"

#include <cassert>
#include <memory>

#include "sdfg/analysis/scope_analysis.h"
#include "sdfg/builder/structured_sdfg_builder.h"
#include "sdfg/data_flow/access_node.h"
#include "sdfg/data_flow/library_nodes/math/cmath/cmath_node.h"
#include "sdfg/structured_control_flow/block.h"
#include "sdfg/structured_control_flow/structured_loop.h"
9

10
namespace sdfg::math::tensor {
11

12

13
// Constructs a BatchNorm library node.
//
// Registers the node with its parent dataflow graph under
// LibraryNodeType_BatchNorm and declares its connectors
// ("Batch", "Var", "E", "Gamma", "Beta", "epsilon", "B_out").
// NOTE(review): the first connector list passed to TensorNode is empty and the
// seven names go in the second list; expand() treats all seven (including
// "B_out") as *input* connectors — confirm against TensorNode's signature.
//
// `layout` describes the tensor shape/strides of the batch input;
// `quantization` is the primitive type used for the computation;
// `impl_type` selects the implementation to expand to.
BatchNormNode::BatchNormNode(
    size_t element_id,
    const DebugInfo& debug_info,
    graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    TensorLayout layout,
    types::PrimitiveType quantization,
    data_flow::ImplementationType impl_type
)
    : TensorNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_BatchNorm,
          {},
          {"Batch", "Var", "E", "Gamma", "Beta", "epsilon", "B_out"},
          std::move(impl_type)
      ),
      layout_(std::move(layout)), quantization_(quantization) {}
33

34
// Returns the set of symbolic variables referenced by this node's
// tensor layout (shape/stride expressions).
symbolic::SymbolSet BatchNormNode::symbols() const {
    symbolic::SymbolSet collected;
    layout_.collect_symbols(collected);
    return collected;
}
39

40
// Returns the primitive type used for the batchnorm computation.
types::PrimitiveType BatchNormNode::quantization() const {
    return quantization_;
}
41

42
// Updates the primitive type used for the batchnorm computation.
void BatchNormNode::set_quantization(const types::PrimitiveType quant) {
    quantization_ = quant;
}
43

44
// Substitutes `old_expression` with `new_expression` inside the layout's
// symbolic shape/stride expressions.
// NOTE(review): the expressions are taken by value; if symbolic::Expression is
// expensive to copy, consider const& — but only if the (likely virtual) base
// declaration is changed in lockstep.
void BatchNormNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    layout_.replace_symbols(old_expression, new_expression);
}
47

48
std::unique_ptr<data_flow::DataFlowNode> BatchNormNode::
49
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
50
    return std::unique_ptr<data_flow::DataFlowNode>(new BatchNormNode(
×
51
        element_id, debug_info(), vertex, parent, this->layout_, this->quantization_, this->implementation_type_
×
52
    ));
×
53
}
×
54

55
// Human-readable description of this node, including its tensor layout.
std::string BatchNormNode::toStr() const {
    return "BatchNorm(" + layout_.toStr() + ")";
}
56

57
// Describes one resolved input of the BatchNorm node: the container feeding
// it, whether that source is a constant, the incoming memlet, and the access
// node to delete once the library node has been expanded.
struct InputContainerInfo {
    std::string name;                                         // source container / constant name
    bool is_const = false;                                    // true if the source is a ConstantNode
    const data_flow::Memlet* memlet;                          // edge into the library node's connector
    const data_flow::AccessNode* access_to_remove = nullptr;  // source access node to clean up

    // Detaches this input from `block`: removes the memlet first, then the
    // access node. (Order preserved deliberately — removing the edge before
    // its endpoint appears intentional; confirm against builder invariants.)
    void remove_old(builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& block) const {
        if (memlet) {
            builder.remove_memlet(block, *memlet);
        }
        if (access_to_remove) {
            builder.remove_node(block, *access_to_remove);
        }
    }
};
72

73
// Resolves the access node feeding connector `input_conn` of `node`.
// Throws InvalidSDFGException if the connector is unconnected or its source
// is not an access node.
InputContainerInfo find_usable_input_access_node(
    data_flow::DataFlowGraph& dataflow, data_flow::LibraryNode& node, const std::string& input_conn
) {
    auto* in_edge = dataflow.in_edge_for_connector(node, input_conn);
    if (in_edge == nullptr) {
        throw InvalidSDFGException(node.toStr() + " requires input on " + input_conn);
    }

    const auto* src_access = dynamic_cast<const data_flow::AccessNode*>(&in_edge->src());
    if (src_access == nullptr) {
        throw InvalidSDFGException(node.toStr() + " requires input on " + input_conn + " to be an access node");
    }

    const bool src_is_constant = dynamic_cast<const data_flow::ConstantNode*>(src_access) != nullptr;
    return {
        .name = src_access->data(),
        .is_const = src_is_constant,
        .memlet = in_edge,
        .access_to_remove = src_access
    };
}
92

93
// One level of the generated loop nest: the loop's induction variable, the
// loop node itself, and the sequence forming the loop body (where the next
// level, or the innermost block, is inserted).
struct BuilderMapDim {
    symbolic::Expression indvar;               // induction variable of this map
    structured_control_flow::StructuredLoop& loop;
    structured_control_flow::Sequence& seq;    // loop body; parent scope of the next dim
};
98

99
std::vector<BuilderMapDim> create_maps(
100
    builder::StructuredSDFGBuilder& builder,
101
    const std::vector<symbolic::Expression>& sizes,
102
    structured_control_flow::Sequence& block
103
) {
1✔
104
    std::vector<BuilderMapDim> scopes;
1✔
105

106
    Sequence* last_scope = &block;
1✔
107

108
    for (size_t i = 0; i < sizes.size(); i++) {
5✔
109
        std::string indvar_str = builder.find_new_name("_i");
4✔
110
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
4✔
111

112
        auto indvar = symbolic::symbol(indvar_str);
4✔
113
        auto init = symbolic::zero();
4✔
114
        auto update = symbolic::add(indvar, symbolic::one());
4✔
115
        auto condition = symbolic::Lt(indvar, sizes.at(i));
4✔
116
        auto& last_map = builder.add_map(
4✔
117
            *last_scope,
4✔
118
            indvar,
4✔
119
            condition,
4✔
120
            init,
4✔
121
            update,
4✔
122
            structured_control_flow::ScheduleType_Sequential::create(),
4✔
123
            {},
4✔
124
            block.debug_info()
4✔
125
        );
4✔
126
        auto& seq = last_map.root();
4✔
127
        last_scope = &seq;
4✔
128

129
        scopes.push_back(
4✔
130
            {.indvar = indvar, .loop = dynamic_cast<structured_control_flow::StructuredLoop&>(last_map), .seq = seq}
4✔
131
        );
4✔
132
    }
4✔
133

134
    return scopes;
1✔
135
}
1✔
136

137
std::string
138
create_temp_var(builder::StructuredSDFGBuilder& builder, const std::string& prefix, int gen, const types::IType& type) {
6✔
139
    std::string n = prefix + "_" + std::to_string(gen);
6✔
140
    auto name = builder.find_new_name(n);
6✔
141
    builder.add_container(name, type);
6✔
142
    return name;
6✔
143
}
6✔
144

145
bool BatchNormNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
1✔
146
    // CPU implementation of batchnorm:
147
    if (false) {
1✔
NEW
148
        auto& dataflow = this->get_parent();
×
NEW
149
        auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
150

NEW
151
        auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
NEW
152
        auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
NEW
153
        int index = parent.index(block);
×
NEW
154
        auto& transition = parent.at(index).second;
×
155

NEW
156
        auto batch_in = find_usable_input_access_node(dataflow, *this, "Batch");
×
NEW
157
        auto& data_type = batch_in.memlet->base_type();
×
NEW
158
        types::Scalar scalar_type(data_type.primitive_type());
×
NEW
159
        types::Tensor tensor_1d(scalar_type, {num_features()}, {symbolic::one()}); // TODO verify / get from inputs
×
NEW
160
        std::string temp_var_prefix = "_batchn_tmp";
×
NEW
161
        int tmp_idx = 0;
×
NEW
162
        auto var_in = find_usable_input_access_node(dataflow, *this, "Var");
×
NEW
163
        auto e_in = find_usable_input_access_node(dataflow, *this, "E");
×
NEW
164
        auto gamma_in = find_usable_input_access_node(dataflow, *this, "Gamma");
×
NEW
165
        auto beta_in = find_usable_input_access_node(dataflow, *this, "Beta");
×
NEW
166
        auto result_ptr_in = find_usable_input_access_node(dataflow, *this, "B_out");
×
NEW
167
        auto eps_in = find_usable_input_access_node(dataflow, *this, "epsilon");
×
168

NEW
169
        auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), debug_info());
×
170

NEW
171
        auto loop_dims = create_maps(builder, layout_.shape(), new_sequence);
×
172

NEW
173
        auto& c_dim = loop_dims.at(1);
×
NEW
174
        std::vector<symbolic::Expression> c_subset{c_dim.indvar};
×
NEW
175
        auto interm_name = builder.find_new_name("_b_sqrt_div");
×
NEW
176
        builder.add_container(interm_name, scalar_type);
×
NEW
177
        auto& inter_block = builder.add_block_before(
×
NEW
178
            c_dim.seq, static_cast<structured_control_flow::ControlFlowNode&>(loop_dims.at(2).loop), {}, DebugInfo()
×
NEW
179
        );
×
180

NEW
181
        auto& var_elem_in = builder.add_access(inter_block, var_in.name);
×
NEW
182
        data_flow::AccessNode& epsilon_const = eps_in.is_const
×
NEW
183
                                                   ? builder.add_constant(inter_block, eps_in.name, scalar_type)
×
NEW
184
                                                   : builder.add_access(inter_block, eps_in.name);
×
185

NEW
186
        auto& add_eps_op = builder.add_tasklet(inter_block, data_flow::fp_add, "_out", {"var", "eps"}, debug_info());
×
187

NEW
188
        builder.add_computational_memlet(inter_block, var_elem_in, add_eps_op, "var", c_subset, tensor_1d);
×
NEW
189
        builder.add_computational_memlet(inter_block, epsilon_const, add_eps_op, "eps", {}, scalar_type);
×
190

NEW
191
        auto tmp_eps_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
NEW
192
        auto& tmp_eps = builder.add_access(inter_block, tmp_eps_name);
×
193

NEW
194
        builder.add_computational_memlet(inter_block, add_eps_op, "_out", tmp_eps, {}, scalar_type);
×
195

NEW
196
        auto tmp_sqrt_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
NEW
197
        auto& tmp_sqrt = builder.add_access(inter_block, tmp_sqrt_name);
×
198

NEW
199
        auto& sqrt_op = builder.add_library_node<
×
NEW
200
            cmath::CMathNode>(inter_block, debug_info(), cmath::CMathFunction::sqrt, data_type.primitive_type());
×
201

NEW
202
        builder.add_computational_memlet(inter_block, tmp_eps, sqrt_op, "_in1", {}, scalar_type);
×
203

NEW
204
        builder.add_computational_memlet(inter_block, sqrt_op, "_out", tmp_sqrt, {}, scalar_type);
×
205

NEW
206
        auto& one_const = builder.add_constant(inter_block, "1.0", scalar_type);
×
NEW
207
        auto& div_op = builder.add_tasklet(inter_block, data_flow::fp_div, "_out", {"one", "sqrt"});
×
NEW
208
        builder.add_computational_memlet(inter_block, one_const, div_op, "one", {}, scalar_type);
×
NEW
209
        builder.add_computational_memlet(inter_block, tmp_sqrt, div_op, "sqrt", {}, scalar_type);
×
210

NEW
211
        auto& interm_store = builder.add_access(inter_block, interm_name);
×
NEW
212
        builder.add_computational_memlet(inter_block, div_op, "_out", interm_store, {}, scalar_type);
×
213

NEW
214
        auto& innermost_dim = loop_dims.at(layout_.dims() - 1);
×
215

NEW
216
        std::vector<symbolic::Expression> innermost_subset;
×
NEW
217
        for (auto& builder_map_dim : loop_dims) {
×
NEW
218
            innermost_subset.push_back(builder_map_dim.indvar);
×
NEW
219
        }
×
220

NEW
221
        auto& innermost_block = builder.add_block(innermost_dim.seq);
×
NEW
222
        auto& x_in = builder.add_access(innermost_block, batch_in.name);
×
NEW
223
        auto& interm_in = builder.add_access(innermost_block, interm_name);
×
NEW
224
        auto& e_elem_in = builder.add_access(innermost_block, e_in.name);
×
NEW
225
        auto& gamma_elem_in = builder.add_access(innermost_block, gamma_in.name);
×
NEW
226
        auto& beta_elem_in = builder.add_access(innermost_block, beta_in.name);
×
227

NEW
228
        auto& result_ptr_out_elem = builder.add_access(innermost_block, result_ptr_in.name);
×
229

NEW
230
        auto& sub_op = builder.add_tasklet(innermost_block, data_flow::fp_sub, "_out", {"x", "e"}, debug_info());
×
231

NEW
232
        builder.add_computational_memlet(innermost_block, x_in, sub_op, "x", innermost_subset, data_type);
×
NEW
233
        builder.add_computational_memlet(innermost_block, e_elem_in, sub_op, "e", c_subset, tensor_1d);
×
NEW
234
        auto tmp_sub_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
NEW
235
        auto& tmp_sub = builder.add_access(innermost_block, tmp_sub_name);
×
NEW
236
        builder.add_computational_memlet(innermost_block, sub_op, "_out", tmp_sub, {}, scalar_type);
×
237

NEW
238
        auto& mul_interm_op =
×
NEW
239
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"num", "den"}, debug_info());
×
240

NEW
241
        builder.add_computational_memlet(innermost_block, tmp_sub, mul_interm_op, "num", {}, scalar_type);
×
NEW
242
        builder.add_computational_memlet(innermost_block, interm_in, mul_interm_op, "den", {}, scalar_type);
×
NEW
243
        auto tmp_interm = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
NEW
244
        auto& tmp_mul_interm = builder.add_access(innermost_block, tmp_interm);
×
NEW
245
        builder.add_computational_memlet(innermost_block, mul_interm_op, "_out", tmp_mul_interm, {}, scalar_type);
×
246

NEW
247
        auto& mul_gamma_op =
×
NEW
248
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"frac", "g"}, debug_info());
×
249

NEW
250
        builder.add_computational_memlet(innermost_block, tmp_mul_interm, mul_gamma_op, "frac", {}, scalar_type);
×
NEW
251
        builder.add_computational_memlet(innermost_block, gamma_elem_in, mul_gamma_op, "g", c_subset, tensor_1d);
×
252

NEW
253
        auto tmp_gamma = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
×
NEW
254
        auto& tmp_mul_gamma = builder.add_access(innermost_block, tmp_gamma);
×
NEW
255
        builder.add_computational_memlet(innermost_block, mul_gamma_op, "_out", tmp_mul_gamma, {}, scalar_type);
×
256

NEW
257
        auto& add_beta_op = builder.add_tasklet(innermost_block, data_flow::fp_add, "_out", {"_in", "b"}, debug_info());
×
258

NEW
259
        builder.add_computational_memlet(innermost_block, tmp_mul_gamma, add_beta_op, "_in", {}, scalar_type);
×
NEW
260
        builder.add_computational_memlet(innermost_block, beta_elem_in, add_beta_op, "b", c_subset, tensor_1d);
×
NEW
261
        builder.add_computational_memlet(
×
NEW
262
            innermost_block, add_beta_op, "_out", result_ptr_out_elem, innermost_subset, data_type
×
NEW
263
        );
×
264

NEW
265
        batch_in.remove_old(builder, block);
×
NEW
266
        var_in.remove_old(builder, block);
×
NEW
267
        e_in.remove_old(builder, block);
×
NEW
268
        eps_in.remove_old(builder, block);
×
NEW
269
        gamma_in.remove_old(builder, block);
×
NEW
270
        beta_in.remove_old(builder, block);
×
NEW
271
        result_ptr_in.remove_old(builder, block);
×
272

NEW
273
        builder.remove_node(block, *this);
×
NEW
274
        assert(dataflow.nodes().size() == 0 && "At expand time, no other nodes may be in the same graph");
×
NEW
275
        builder.remove_child(parent, index + 1);
×
276

NEW
277
        return true;
×
278
    } else {
1✔
279
        // GPU implementation of batchnorm:
280
        // Move sqrt and division into the innermost loop to enable more parallelism.
281
        auto& dataflow = this->get_parent();
1✔
282
        auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
1✔
283

284
        auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
1✔
285
        auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
1✔
286
        int index = parent.index(block);
1✔
287
        auto& transition = parent.at(index).second;
1✔
288

289
        auto batch_in = find_usable_input_access_node(dataflow, *this, "Batch");
1✔
290
        auto& data_type = batch_in.memlet->base_type();
1✔
291
        types::Scalar scalar_type(data_type.primitive_type());
1✔
292
        types::Tensor tensor_1d(scalar_type, {num_features()}, {symbolic::one()});
1✔
293
        std::string temp_var_prefix = "_batchn_tmp";
1✔
294
        int tmp_idx = 0;
1✔
295
        auto var_in = find_usable_input_access_node(dataflow, *this, "Var");
1✔
296
        auto e_in = find_usable_input_access_node(dataflow, *this, "E");
1✔
297
        auto gamma_in = find_usable_input_access_node(dataflow, *this, "Gamma");
1✔
298
        auto beta_in = find_usable_input_access_node(dataflow, *this, "Beta");
1✔
299
        auto result_ptr_in = find_usable_input_access_node(dataflow, *this, "B_out");
1✔
300
        auto eps_in = find_usable_input_access_node(dataflow, *this, "epsilon");
1✔
301

302
        auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), debug_info());
1✔
303

304
        auto loop_dims = create_maps(builder, layout_.shape(), new_sequence);
1✔
305

306
        auto& c_dim = loop_dims.at(1);
1✔
307
        std::vector<symbolic::Expression> c_subset{c_dim.indvar};
1✔
308

309
        auto& innermost_dim = loop_dims.at(layout_.dims() - 1);
1✔
310

311
        std::vector<symbolic::Expression> innermost_subset;
1✔
312
        for (auto& builder_map_dim : loop_dims) {
4✔
313
            innermost_subset.push_back(builder_map_dim.indvar);
4✔
314
        }
4✔
315

316
        auto& innermost_block = builder.add_block(innermost_dim.seq);
1✔
317

318
        // Access nodes
319
        auto& x_in = builder.add_access(innermost_block, batch_in.name);
1✔
320
        auto& var_elem_in = builder.add_access(innermost_block, var_in.name);
1✔
321
        data_flow::AccessNode& epsilon_const = eps_in.is_const
1✔
322
                                                   ? builder.add_constant(innermost_block, eps_in.name, scalar_type)
1✔
323
                                                   : builder.add_access(innermost_block, eps_in.name);
1✔
324
        auto& e_elem_in = builder.add_access(innermost_block, e_in.name);
1✔
325
        auto& gamma_elem_in = builder.add_access(innermost_block, gamma_in.name);
1✔
326
        auto& beta_elem_in = builder.add_access(innermost_block, beta_in.name);
1✔
327
        auto& result_ptr_out_elem = builder.add_access(innermost_block, result_ptr_in.name);
1✔
328

329
        // var[c] + eps
330
        auto& add_eps_op =
1✔
331
            builder.add_tasklet(innermost_block, data_flow::fp_add, "_out", {"var", "eps"}, debug_info());
1✔
332
        builder.add_computational_memlet(innermost_block, var_elem_in, add_eps_op, "var", c_subset, tensor_1d);
1✔
333
        builder.add_computational_memlet(innermost_block, epsilon_const, add_eps_op, "eps", {}, scalar_type);
1✔
334
        auto tmp_eps_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
335
        auto& tmp_eps = builder.add_access(innermost_block, tmp_eps_name);
1✔
336
        builder.add_computational_memlet(innermost_block, add_eps_op, "_out", tmp_eps, {}, scalar_type);
1✔
337

338
        // sqrt(var[c] + eps)
339
        auto tmp_sqrt_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
340
        auto& tmp_sqrt = builder.add_access(innermost_block, tmp_sqrt_name);
1✔
341
        auto& sqrt_op = builder.add_library_node<
1✔
342
            cmath::CMathNode>(innermost_block, debug_info(), cmath::CMathFunction::sqrt, data_type.primitive_type());
1✔
343
        builder.add_computational_memlet(innermost_block, tmp_eps, sqrt_op, "_in1", {}, scalar_type);
1✔
344
        builder.add_computational_memlet(innermost_block, sqrt_op, "_out", tmp_sqrt, {}, scalar_type);
1✔
345

346
        // 1.0 / sqrt(var[c] + eps)
347
        auto& one_const = builder.add_constant(innermost_block, "1.0", scalar_type);
1✔
348
        auto& div_op = builder.add_tasklet(innermost_block, data_flow::fp_div, "_out", {"one", "sqrt"});
1✔
349
        builder.add_computational_memlet(innermost_block, one_const, div_op, "one", {}, scalar_type);
1✔
350
        builder.add_computational_memlet(innermost_block, tmp_sqrt, div_op, "sqrt", {}, scalar_type);
1✔
351
        auto interm_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
352
        auto& interm_store = builder.add_access(innermost_block, interm_name);
1✔
353
        builder.add_computational_memlet(innermost_block, div_op, "_out", interm_store, {}, scalar_type);
1✔
354

355
        // x - e[c]
356
        auto& sub_op = builder.add_tasklet(innermost_block, data_flow::fp_sub, "_out", {"x", "e"}, debug_info());
1✔
357
        builder.add_computational_memlet(innermost_block, x_in, sub_op, "x", innermost_subset, data_type);
1✔
358
        builder.add_computational_memlet(innermost_block, e_elem_in, sub_op, "e", c_subset, tensor_1d);
1✔
359
        auto tmp_sub_name = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
360
        auto& tmp_sub = builder.add_access(innermost_block, tmp_sub_name);
1✔
361
        builder.add_computational_memlet(innermost_block, sub_op, "_out", tmp_sub, {}, scalar_type);
1✔
362

363
        // (x - e[c]) * (1/sqrt(var[c]+eps))
364
        auto& mul_interm_op =
1✔
365
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"num", "den"}, debug_info());
1✔
366
        builder.add_computational_memlet(innermost_block, tmp_sub, mul_interm_op, "num", {}, scalar_type);
1✔
367
        builder.add_computational_memlet(innermost_block, interm_store, mul_interm_op, "den", {}, scalar_type);
1✔
368
        auto tmp_interm = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
369
        auto& tmp_mul_interm = builder.add_access(innermost_block, tmp_interm);
1✔
370
        builder.add_computational_memlet(innermost_block, mul_interm_op, "_out", tmp_mul_interm, {}, scalar_type);
1✔
371

372
        // * gamma[c]
373
        auto& mul_gamma_op =
1✔
374
            builder.add_tasklet(innermost_block, data_flow::fp_mul, "_out", {"frac", "g"}, debug_info());
1✔
375
        builder.add_computational_memlet(innermost_block, tmp_mul_interm, mul_gamma_op, "frac", {}, scalar_type);
1✔
376
        builder.add_computational_memlet(innermost_block, gamma_elem_in, mul_gamma_op, "g", c_subset, tensor_1d);
1✔
377
        auto tmp_gamma = create_temp_var(builder, temp_var_prefix, tmp_idx++, scalar_type);
1✔
378
        auto& tmp_mul_gamma = builder.add_access(innermost_block, tmp_gamma);
1✔
379
        builder.add_computational_memlet(innermost_block, mul_gamma_op, "_out", tmp_mul_gamma, {}, scalar_type);
1✔
380

381
        // + beta[c]
382
        auto& add_beta_op = builder.add_tasklet(innermost_block, data_flow::fp_add, "_out", {"_in", "b"}, debug_info());
1✔
383
        builder.add_computational_memlet(innermost_block, tmp_mul_gamma, add_beta_op, "_in", {}, scalar_type);
1✔
384
        builder.add_computational_memlet(innermost_block, beta_elem_in, add_beta_op, "b", c_subset, tensor_1d);
1✔
385
        builder.add_computational_memlet(
1✔
386
            innermost_block, add_beta_op, "_out", result_ptr_out_elem, innermost_subset, data_type
1✔
387
        );
1✔
388

389
        batch_in.remove_old(builder, block);
1✔
390
        var_in.remove_old(builder, block);
1✔
391
        e_in.remove_old(builder, block);
1✔
392
        eps_in.remove_old(builder, block);
1✔
393
        gamma_in.remove_old(builder, block);
1✔
394
        beta_in.remove_old(builder, block);
1✔
395
        result_ptr_in.remove_old(builder, block);
1✔
396

397
        builder.remove_node(block, *this);
1✔
398
        assert(dataflow.nodes().size() == 0 && "At expand time, no other nodes may be in the same graph");
1✔
399
        builder.remove_child(parent, index + 1);
1✔
400

401
        return true;
1✔
402
    }
1✔
403
}
1✔
404

405
// Symbolic estimate of the floating-point operations performed by this node.
// NOTE(review): assumes the layout has at least two innermost and two
// outermost dimensions (e.g. NCHW) — shape().at(0)/at(1) and
// get_dim_innermost(0)/(1) would throw/misbehave otherwise; confirm for
// lower-rank layouts.
symbolic::Expression BatchNormNode::flop() const {
    // Elements covered by the two innermost dimensions (per-(n,c) work).
    auto inner_elems = symbolic::mul(layout_.get_dim_innermost(0), layout_.get_dim_innermost(1));
    // Iterations of the two outermost dimensions (presumably batch x channel).
    auto outer_elems = symbolic::mul(layout_.shape().at(0), layout_.shape().at(1));

    // (x-e) * sqrt_pre_calc * g + b = 4 flops
    auto inner_flops = symbolic::mul(symbolic::integer(4), inner_elems);
    // sqrt_pre_calc = 1/sqrt(var + eps) // 3 flops
    auto outer_flops = symbolic::mul(symbolic::add(inner_flops, symbolic::integer(3)), outer_elems);
    return outer_flops;
}
415

416
// Serializes a BatchNormNode to JSON: its code, batch layout, and
// quantization type.
nlohmann::json BatchNormNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& batchnorm = static_cast<const BatchNormNode&>(library_node);

    nlohmann::json result;
    result["code"] = batchnorm.code().value();
    batchnorm.batch_layout().serialize_to_json(result["batch_layout"]);
    result["batch_quant"] = batchnorm.quantization();

    return result;
}
428

429
// Reconstructs a BatchNormNode from its JSON representation and inserts it
// into `parent`. Inverse of BatchNormNodeSerializer::serialize.
data_flow::LibraryNode& BatchNormNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Read fields in the same order as before so missing-key errors surface
    // identically.
    auto restored_layout = TensorLayout::deserialize_from_json(j.at("batch_layout"));
    const auto restored_quant = j.at("batch_quant").get<types::PrimitiveType>();

    serializer::JSONSerializer json_serializer;
    const auto restored_debug_info = json_serializer.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<BatchNormNode>(parent, restored_debug_info, restored_layout, restored_quant);
}
440

441
} // namespace sdfg::math::tensor
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc