• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23052075684

13 Mar 2026 01:03PM UTC coverage: 63.794%. First build
23052075684

Pull #582

github

web-flow
Merge f1049120e into 9bccac573
Pull Request #582: [MLIR] Pooling layer Support

324 of 466 new or added lines in 3 files covered. (69.53%)

25600 of 40129 relevant lines covered (63.79%)

396.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.23
/sdfg/src/data_flow/library_nodes/math/tensor/pooling_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/pooling_node.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/analysis/scope_analysis.h"
5
#include "sdfg/builder/structured_sdfg_builder.h"
6
#include "sdfg/data_flow/library_nodes/math/cmath/cmath_node.h"
7
#include "sdfg/data_flow/library_nodes/math/tensor/tensor_node.h"
8
#include "sdfg/types/type.h"
9

10
namespace sdfg {
11
namespace math {
12
namespace tensor {
13

14
PoolingNode::PoolingNode(
15
    size_t element_id,
16
    const DebugInfo& debug_info,
17
    const graph::Vertex vertex,
18
    data_flow::DataFlowGraph& parent,
19
    PoolingMode mode,
20
    const std::vector<symbolic::Expression>& shape,
21
    const std::vector<symbolic::Expression>& kernel_shape,
22
    const std::vector<symbolic::Expression>& strides,
23
    const std::vector<symbolic::Expression>& pads,
24
    const std::vector<symbolic::Expression>& dilations
25
)
26
    : TensorNode(
15✔
27
          element_id, debug_info, vertex, parent, LibraryNodeType_Pooling, {"Y"}, {"X"}, data_flow::ImplementationType_NONE
15✔
28
      ),
15✔
29
      mode_(mode), shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations) {}
15✔
30

31
void PoolingNode::validate(const Function& function) const {
12✔
32
    TensorNode::validate(function);
12✔
33

34
    if (kernel_shape_.empty()) {
12✔
NEW
35
        throw InvalidSDFGException("PoolingNode kernel_shape cannot be empty");
×
NEW
36
    }
×
37

38
    size_t spatial_dims = kernel_shape_.size();
12✔
39

40
    if (!strides_.empty() && strides_.size() != spatial_dims) {
12✔
NEW
41
        throw InvalidSDFGException("PoolingNode strides must match kernel spatial dimensions");
×
NEW
42
    }
×
43

44
    if (!pads_.empty() && pads_.size() != 2 * spatial_dims) {
12✔
NEW
45
        throw InvalidSDFGException("PoolingNode pads must have 2 * spatial dimensions");
×
NEW
46
    }
×
47

48
    if (!dilations_.empty() && dilations_.size() != spatial_dims) {
12✔
NEW
49
        throw InvalidSDFGException("PoolingNode dilations must match kernel spatial dimensions");
×
NEW
50
    }
×
51
}
12✔
52

53
bool PoolingNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
9✔
54
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
9✔
55

56
    auto& dataflow = this->get_parent();
9✔
57
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
9✔
58
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
9✔
59
    int index = parent.index(block);
9✔
60
    auto& transition = parent.at(index).second;
9✔
61

62
    auto primitive_type = this->primitive_type(dataflow);
9✔
63
    types::Scalar scalar_type(primitive_type);
9✔
64

65
    auto in_edges = dataflow.in_edges(*this);
9✔
66
    data_flow::Memlet* x_edge = nullptr;
9✔
67
    auto in_edges_it = in_edges.begin();
9✔
68
    while (in_edges_it != in_edges.end()) {
18✔
69
        auto& edge = *in_edges_it;
9✔
70
        if (edge.dst_conn() == "X") {
9✔
71
            x_edge = &edge;
9✔
72
        }
9✔
73
        ++in_edges_it;
9✔
74
    }
9✔
75
    if (!x_edge) {
9✔
NEW
76
        return false;
×
NEW
77
    }
×
78

79
    auto& y_edge = *dataflow.out_edges(*this).begin();
9✔
80

81
    auto* x_node = static_cast<data_flow::AccessNode*>(&x_edge->src());
9✔
82
    auto* y_node = static_cast<data_flow::AccessNode*>(&y_edge.dst());
9✔
83

84
    if (!x_node || dataflow.in_degree(*x_node) != 0 || !y_node || dataflow.out_degree(*y_node) != 0) {
9✔
NEW
85
        return false;
×
NEW
86
    }
×
87

88
    size_t spatial_dims = kernel_shape_.size();
9✔
89
    if (spatial_dims == 0) {
9✔
NEW
90
        return false;
×
NEW
91
    }
×
92

93
    // Get strides (default to 1)
94
    std::vector<symbolic::Expression> strides_vec;
9✔
95
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
96
        if (i < strides_.size()) {
17✔
97
            strides_vec.push_back(strides_[i]);
17✔
98
        } else {
17✔
NEW
99
            strides_vec.push_back(symbolic::one());
×
NEW
100
        }
×
101
    }
17✔
102

103
    // Get padding (default to 0)
104
    std::vector<symbolic::Expression> pads_begin_vec, pads_end_vec;
9✔
105
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
106
        if (i < pads_.size()) {
17✔
107
            pads_begin_vec.push_back(pads_[i]);
2✔
108
        } else {
15✔
109
            pads_begin_vec.push_back(symbolic::zero());
15✔
110
        }
15✔
111
        if (spatial_dims + i < pads_.size()) {
17✔
112
            pads_end_vec.push_back(pads_[spatial_dims + i]);
2✔
113
        } else {
15✔
114
            pads_end_vec.push_back(symbolic::zero());
15✔
115
        }
15✔
116
    }
17✔
117

118
    // Get dilations (default to 1)
119
    std::vector<symbolic::Expression> dilations_vec;
9✔
120
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
121
        if (i < dilations_.size()) {
17✔
122
            dilations_vec.push_back(dilations_[i]);
2✔
123
        } else {
15✔
124
            dilations_vec.push_back(symbolic::one());
15✔
125
        }
15✔
126
    }
17✔
127

128
    auto& X_var = x_node->data();
9✔
129
    auto& Y_var = y_node->data();
9✔
130

131
    // Input shape: [N, C, D0, D1, ..., Dn]
132
    symbolic::Expression N = shape_[0];
9✔
133
    symbolic::Expression C = shape_[1];
9✔
134
    std::vector<symbolic::Expression> input_spatial_dims;
9✔
135
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
136
        input_spatial_dims.push_back(shape_[2 + i]);
17✔
137
    }
17✔
138

139
    // Output spatial dimensions
140
    std::vector<symbolic::Expression> output_spatial_dims;
9✔
141
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
142
        auto d_in = input_spatial_dims[i];
17✔
143
        auto pad = symbolic::add(pads_begin_vec[i], pads_end_vec[i]);
17✔
144
        auto dk = symbolic::mul(dilations_vec[i], symbolic::sub(kernel_shape_[i], symbolic::one()));
17✔
145
        auto num = symbolic::sub(symbolic::add(d_in, pad), symbolic::add(dk, symbolic::one()));
17✔
146
        auto d_out = symbolic::add(symbolic::div(num, strides_vec[i]), symbolic::one());
17✔
147
        output_spatial_dims.push_back(d_out);
17✔
148
    }
17✔
149

150
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
9✔
151

152
    structured_control_flow::Sequence* current_scope = &new_sequence;
9✔
153
    std::vector<symbolic::Expression> output_indices;
9✔
154
    std::vector<symbolic::Expression> output_spatial_vars;
9✔
155

156
    // Map over batch
157
    std::string n_str = builder.find_new_name("n");
9✔
158
    builder.add_container(n_str, types::Scalar(types::PrimitiveType::UInt64));
9✔
159
    auto n_var = symbolic::symbol(n_str);
9✔
160
    auto& map_n = builder.add_map(
9✔
161
        *current_scope,
9✔
162
        n_var,
9✔
163
        symbolic::Lt(n_var, N),
9✔
164
        symbolic::zero(),
9✔
165
        symbolic::add(n_var, symbolic::one()),
9✔
166
        structured_control_flow::ScheduleType_Sequential::create(),
9✔
167
        {},
9✔
168
        block.debug_info()
9✔
169
    );
9✔
170
    current_scope = &map_n.root();
9✔
171
    output_indices.push_back(n_var);
9✔
172

173
    // Map over channel
174
    std::string c_str = builder.find_new_name("c");
9✔
175
    builder.add_container(c_str, types::Scalar(types::PrimitiveType::UInt64));
9✔
176
    auto c_var = symbolic::symbol(c_str);
9✔
177
    auto& map_c = builder.add_map(
9✔
178
        *current_scope,
9✔
179
        c_var,
9✔
180
        symbolic::Lt(c_var, C),
9✔
181
        symbolic::zero(),
9✔
182
        symbolic::add(c_var, symbolic::one()),
9✔
183
        structured_control_flow::ScheduleType_Sequential::create(),
9✔
184
        {},
9✔
185
        block.debug_info()
9✔
186
    );
9✔
187
    current_scope = &map_c.root();
9✔
188
    output_indices.push_back(c_var);
9✔
189

190
    // Map over each output spatial dimension
191
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
192
        std::string od_str = builder.find_new_name("od" + std::to_string(i));
17✔
193
        builder.add_container(od_str, types::Scalar(types::PrimitiveType::UInt64));
17✔
194
        auto od_var = symbolic::symbol(od_str);
17✔
195
        auto& map_od = builder.add_map(
17✔
196
            *current_scope,
17✔
197
            od_var,
17✔
198
            symbolic::Lt(od_var, output_spatial_dims[i]),
17✔
199
            symbolic::zero(),
17✔
200
            symbolic::add(od_var, symbolic::one()),
17✔
201
            structured_control_flow::ScheduleType_Sequential::create(),
17✔
202
            {},
17✔
203
            block.debug_info()
17✔
204
        );
17✔
205
        current_scope = &map_od.root();
17✔
206
        output_indices.push_back(od_var);
17✔
207
        output_spatial_vars.push_back(od_var);
17✔
208
    }
17✔
209

210
    // Create accumulator
211
    std::string accum_var = builder.find_new_name("_pool_accum");
9✔
212
    builder.add_container(accum_var, scalar_type);
9✔
213

214
    // Initialize accumulator
215
    std::string init_value;
9✔
216
    if (mode_ == PoolingMode::Max) {
9✔
217
        // Use -INFINITY for float, type-min for integers
218
        if (types::is_integer(primitive_type)) {
7✔
219
            switch (primitive_type) {
1✔
NEW
220
                case types::PrimitiveType::Int8:
×
NEW
221
                    init_value = "INT8_MIN";
×
NEW
222
                    break;
×
NEW
223
                case types::PrimitiveType::Int16:
×
NEW
224
                    init_value = "INT16_MIN";
×
NEW
225
                    break;
×
226
                case types::PrimitiveType::Int32:
1✔
227
                    init_value = "INT32_MIN";
1✔
228
                    break;
1✔
NEW
229
                case types::PrimitiveType::Int64:
×
NEW
230
                    init_value = "INT64_MIN";
×
NEW
231
                    break;
×
NEW
232
                default:
×
NEW
233
                    init_value = "0";
×
NEW
234
                    break;
×
235
            }
1✔
236
        } else {
6✔
237
            init_value = "-INFINITY";
6✔
238
        }
6✔
239
    } else {
7✔
240
        // Sum / Avg: init to 0
241
        init_value = types::is_integer(primitive_type) ? "0" : "0.0";
2✔
242
    }
2✔
243

244
    auto& init_block = builder.add_block(*current_scope, {}, block.debug_info());
9✔
245
    auto& accum_init = builder.add_access(init_block, accum_var, block.debug_info());
9✔
246
    auto& zero_const = builder.add_constant(init_block, init_value, scalar_type, block.debug_info());
9✔
247
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
9✔
248
    builder.add_computational_memlet(init_block, zero_const, init_tasklet, "_in", {}, scalar_type, block.debug_info());
9✔
249
    builder.add_computational_memlet(init_block, init_tasklet, "_out", accum_init, {}, scalar_type, block.debug_info());
9✔
250

251
    // For loops over kernel spatial dimensions
252
    auto* loop_scope = current_scope;
9✔
253
    std::vector<symbolic::Expression> kernel_vars;
9✔
254
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
255
        std::string k_str = builder.find_new_name("k" + std::to_string(i));
17✔
256
        builder.add_container(k_str, types::Scalar(types::PrimitiveType::UInt64));
17✔
257
        auto k_var = symbolic::symbol(k_str);
17✔
258
        auto& for_k = builder.add_for(
17✔
259
            *loop_scope,
17✔
260
            k_var,
17✔
261
            symbolic::Lt(k_var, kernel_shape_[i]),
17✔
262
            symbolic::zero(),
17✔
263
            symbolic::add(k_var, symbolic::one()),
17✔
264
            {},
17✔
265
            block.debug_info()
17✔
266
        );
17✔
267
        loop_scope = &for_k.root();
17✔
268
        kernel_vars.push_back(k_var);
17✔
269
    }
17✔
270

271
    // Compute input spatial indices
272
    std::vector<symbolic::Expression> input_spatial_indices;
9✔
273
    for (size_t i = 0; i < spatial_dims; ++i) {
26✔
274
        auto k_dilated = symbolic::mul(kernel_vars[i], dilations_vec[i]);
17✔
275
        auto input_idx = symbolic::
17✔
276
            add(symbolic::sub(symbolic::mul(output_spatial_vars[i], strides_vec[i]), pads_begin_vec[i]), k_dilated);
17✔
277
        input_spatial_indices.push_back(input_idx);
17✔
278
    }
17✔
279

280
    // Build X indices: [n, c, input_spatial...]
281
    std::vector<symbolic::Expression> x_indices_vec = {n_var, c_var};
9✔
282
    x_indices_vec.insert(x_indices_vec.end(), input_spatial_indices.begin(), input_spatial_indices.end());
9✔
283
    data_flow::Subset x_subset(x_indices_vec);
9✔
284

285
    // Computation block: accumulate
286
    auto& comp_block = builder.add_block(*loop_scope, {}, block.debug_info());
9✔
287
    auto& x_access = builder.add_access(comp_block, X_var, x_node->debug_info());
9✔
288
    auto& accum_read = builder.add_access(comp_block, accum_var, block.debug_info());
9✔
289
    auto& accum_write = builder.add_access(comp_block, accum_var, block.debug_info());
9✔
290

291
    if (mode_ == PoolingMode::Max) {
9✔
292
        bool is_int = types::is_integer(primitive_type);
7✔
293
        if (is_int) {
7✔
294
            auto tasklet_code = TensorNode::get_integer_minmax_tasklet(primitive_type, true);
1✔
295
            auto& tasklet = builder.add_tasklet(comp_block, tasklet_code, "_out", {"_in1", "_in2"}, block.debug_info());
1✔
296
            builder.add_computational_memlet(
1✔
297
                comp_block, x_access, tasklet, "_in1", x_subset, x_edge->base_type(), block.debug_info()
1✔
298
            );
1✔
299
            builder
1✔
300
                .add_computational_memlet(comp_block, accum_read, tasklet, "_in2", {}, scalar_type, block.debug_info());
1✔
301
            builder
1✔
302
                .add_computational_memlet(comp_block, tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());
1✔
303
        } else {
6✔
304
            auto& libnode = builder.add_library_node<
6✔
305
                math::cmath::CMathNode>(comp_block, block.debug_info(), cmath::CMathFunction::fmax, primitive_type);
6✔
306
            builder.add_computational_memlet(
6✔
307
                comp_block, x_access, libnode, "_in1", x_subset, x_edge->base_type(), block.debug_info()
6✔
308
            );
6✔
309
            builder
6✔
310
                .add_computational_memlet(comp_block, accum_read, libnode, "_in2", {}, scalar_type, block.debug_info());
6✔
311
            builder
6✔
312
                .add_computational_memlet(comp_block, libnode, "_out", accum_write, {}, scalar_type, block.debug_info());
6✔
313
        }
6✔
314
    } else {
7✔
315
        // Sum or Avg: accumulate with addition
316
        bool is_int = types::is_integer(primitive_type);
2✔
317
        data_flow::TaskletCode opcode = is_int ? data_flow::TaskletCode::int_add : data_flow::TaskletCode::fp_add;
2✔
318
        auto& tasklet = builder.add_tasklet(comp_block, opcode, "_out", {"_in1", "_in2"}, block.debug_info());
2✔
319
        builder.add_computational_memlet(
2✔
320
            comp_block, x_access, tasklet, "_in1", x_subset, x_edge->base_type(), block.debug_info()
2✔
321
        );
2✔
322
        builder.add_computational_memlet(comp_block, accum_read, tasklet, "_in2", {}, scalar_type, block.debug_info());
2✔
323
        builder.add_computational_memlet(comp_block, tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());
2✔
324
    }
2✔
325

326
    // After kernel loops: write result to output
327
    data_flow::Subset y_subset(output_indices);
9✔
328

329
    auto& output_block = builder.add_block(*current_scope, {}, block.debug_info());
9✔
330
    auto& accum_final = builder.add_access(output_block, accum_var, block.debug_info());
9✔
331
    auto& y_access = builder.add_access(output_block, Y_var, y_node->debug_info());
9✔
332

333
    if (mode_ == PoolingMode::Avg) {
9✔
334
        // Divide by window size: product of kernel_shape dimensions
335
        // Create a temporary for the divisor
336
        std::string divisor_var = builder.find_new_name("_pool_divisor");
1✔
337
        builder.add_container(divisor_var, scalar_type);
1✔
338

339
        // Compute window size as product of kernel dimensions
340
        symbolic::Expression window_size = kernel_shape_[0];
1✔
341
        for (size_t i = 1; i < spatial_dims; ++i) {
2✔
342
            window_size = symbolic::mul(window_size, kernel_shape_[i]);
1✔
343
        }
1✔
344

345
        auto& divisor_const =
1✔
346
            builder.add_constant(output_block, window_size->__str__(), scalar_type, block.debug_info());
1✔
347
        auto& divisor_access = builder.add_access(output_block, divisor_var, block.debug_info());
1✔
348
        auto& divisor_assign =
1✔
349
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
1✔
350
        builder.add_computational_memlet(
1✔
351
            output_block, divisor_const, divisor_assign, "_in", {}, scalar_type, block.debug_info()
1✔
352
        );
1✔
353
        builder.add_computational_memlet(
1✔
354
            output_block, divisor_assign, "_out", divisor_access, {}, scalar_type, block.debug_info()
1✔
355
        );
1✔
356

357
        bool is_int = types::is_integer(primitive_type);
1✔
358
        data_flow::TaskletCode div_opcode = is_int ? data_flow::TaskletCode::int_sdiv : data_flow::TaskletCode::fp_div;
1✔
359
        auto& div_tasklet = builder.add_tasklet(output_block, div_opcode, "_out", {"_in1", "_in2"}, block.debug_info());
1✔
360
        builder
1✔
361
            .add_computational_memlet(output_block, accum_final, div_tasklet, "_in1", {}, scalar_type, block.debug_info());
1✔
362
        builder.add_computational_memlet(
1✔
363
            output_block, divisor_access, div_tasklet, "_in2", {}, scalar_type, block.debug_info()
1✔
364
        );
1✔
365
        builder.add_computational_memlet(
1✔
366
            output_block, div_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
1✔
367
        );
1✔
368
    } else {
8✔
369
        // Max or Sum: just assign
370
        auto& assign_tasklet =
8✔
371
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
8✔
372
        builder.add_computational_memlet(
8✔
373
            output_block, accum_final, assign_tasklet, "_in", {}, scalar_type, block.debug_info()
8✔
374
        );
8✔
375
        builder.add_computational_memlet(
8✔
376
            output_block, assign_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
8✔
377
        );
8✔
378
    }
8✔
379

380
    // Clean up original block
381
    builder.remove_memlet(block, *x_edge);
9✔
382
    builder.remove_memlet(block, y_edge);
9✔
383
    builder.remove_node(block, *x_node);
9✔
384
    builder.remove_node(block, *y_node);
9✔
385
    builder.remove_node(block, *this);
9✔
386
    builder.remove_child(parent, index + 1);
9✔
387

388
    return true;
9✔
389
}
9✔
390

391
/// Returns every atom referenced by any pooling attribute.
///
/// The attributes are scanned in the order shape, kernel_shape, strides, pads,
/// dilations, and the atoms of each expression are inserted into the result set.
symbolic::SymbolSet PoolingNode::symbols() const {
    symbolic::SymbolSet result;

    // Adds the atoms of every expression in one attribute list to `result`.
    auto gather = [&result](const auto& expressions) {
        for (const auto& expression : expressions) {
            for (const auto& symbol : symbolic::atoms(expression)) {
                result.insert(symbol);
            }
        }
    };

    gather(shape_);
    gather(kernel_shape_);
    gather(strides_);
    gather(pads_);
    gather(dilations_);

    return result;
}
3✔
420

421
/// Substitutes `old_expression` with `new_expression` in every pooling attribute
/// (shape, kernel_shape, strides, pads, dilations), in place.
void PoolingNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Rewrites each element of one attribute list via symbolic::subs.
    auto substitute_in = [&old_expression, &new_expression](auto& expressions) {
        for (auto& expression : expressions) {
            expression = symbolic::subs(expression, old_expression, new_expression);
        }
    };

    substitute_in(shape_);
    substitute_in(kernel_shape_);
    substitute_in(strides_);
    substitute_in(pads_);
    substitute_in(dilations_);
}
1✔
438

439
std::unique_ptr<data_flow::DataFlowNode> PoolingNode::
NEW
440
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
441
    return std::unique_ptr<data_flow::DataFlowNode>(new PoolingNode(
×
NEW
442
        element_id, this->debug_info(), vertex, parent, mode_, shape_, kernel_shape_, strides_, pads_, dilations_
×
NEW
443
    ));
×
NEW
444
}
×
445

446
/// Returns the serialized name of a pooling mode: "max", "sum" or "avg".
/// Any value outside the known enumerators maps to "unknown".
std::string PoolingNode::mode_to_string(PoolingMode mode) {
    const char* name = "unknown";
    switch (mode) {
        case PoolingMode::Max:
            name = "max";
            break;
        case PoolingMode::Sum:
            name = "sum";
            break;
        case PoolingMode::Avg:
            name = "avg";
            break;
    }
    return name;
}
3✔
457

458
/// Parses a pooling mode name ("max", "sum" or "avg"); the inverse of mode_to_string.
/// @throws InvalidSDFGException for any other string.
PoolingMode PoolingNode::string_to_mode(const std::string& str) {
    struct ModeName {
        const char* name;
        PoolingMode mode;
    };
    static const ModeName known_modes[] = {
        {"max", PoolingMode::Max},
        {"sum", PoolingMode::Sum},
        {"avg", PoolingMode::Avg},
    };

    for (const auto& candidate : known_modes) {
        if (str == candidate.name) {
            return candidate.mode;
        }
    }
    throw InvalidSDFGException("Unknown pooling mode: " + str);
}
1✔
464

NEW
465
std::string PoolingNode::toStr() const {
×
NEW
466
    std::string result = "Pooling(mode=" + mode_to_string(mode_) + ", shape=[";
×
NEW
467
    for (size_t i = 0; i < shape_.size(); ++i) {
×
NEW
468
        if (i > 0) result += ", ";
×
NEW
469
        result += shape_[i]->__str__();
×
NEW
470
    }
×
NEW
471
    result += "], kernel_shape=[";
×
NEW
472
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
NEW
473
        if (i > 0) result += ", ";
×
NEW
474
        result += kernel_shape_[i]->__str__();
×
NEW
475
    }
×
NEW
476
    result += "], strides=[";
×
NEW
477
    for (size_t i = 0; i < strides_.size(); ++i) {
×
NEW
478
        if (i > 0) result += ", ";
×
NEW
479
        result += strides_[i]->__str__();
×
NEW
480
    }
×
NEW
481
    result += "])";
×
NEW
482
    return result;
×
NEW
483
}
×
484

NEW
485
/// Serializes a PoolingNode into JSON.
///
/// Emitted keys: "code", "mode" (via PoolingNode::mode_to_string), and the arrays
/// "shape", "kernel_shape", "strides", "pads", "dilations". Each array entry is
/// produced by JSONSerializer::expression. `library_node` must be a PoolingNode.
nlohmann::json PoolingNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& pooling = static_cast<const PoolingNode&>(library_node);

    nlohmann::json out;
    out["code"] = pooling.code().value();
    out["mode"] = PoolingNode::mode_to_string(pooling.mode());

    serializer::JSONSerializer expr_writer;

    // Converts one attribute list into a JSON array of serialized expressions.
    auto to_array = [&expr_writer](const auto& expressions) {
        nlohmann::json values = nlohmann::json::array();
        for (const auto& expression : expressions) {
            values.push_back(expr_writer.expression(expression));
        }
        return values;
    };

    out["shape"] = to_array(pooling.shape());
    out["kernel_shape"] = to_array(pooling.kernel_shape());
    out["strides"] = to_array(pooling.strides());
    out["pads"] = to_array(pooling.pads());
    out["dilations"] = to_array(pooling.dilations());

    return out;
}
×
521

522
/// Rebuilds a PoolingNode from JSON produced by PoolingNodeSerializer::serialize and
/// adds it to `parent` through `builder`.
///
/// Required keys: "element_id", "code", "debug_info", "mode", "kernel_shape".
/// Optional keys: "shape", "strides", "pads", "dilations" (missing -> empty list).
/// Every array entry is a string parsed with symbolic::parse.
///
/// @throws InvalidSDFGException for an unknown mode string (from string_to_mode).
/// @throws nlohmann::json exceptions if a required key is missing (when asserts are
///         disabled) or an entry has the wrong JSON type.
data_flow::LibraryNode& PoolingNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("mode"));
    assert(j.contains("kernel_shape"));

    // Keys are read with at(): on a const json, operator[] with a missing key is undefined
    // behaviour, and the asserts above vanish under NDEBUG. at() throws instead.

    // Parses a JSON array of expression strings.
    auto parse_exprs = [](const nlohmann::json& values) {
        std::vector<symbolic::Expression> exprs;
        for (const auto& value : values) {
            exprs.push_back(symbolic::parse(value.get<std::string>()));
        }
        return exprs;
    };
    // Optional attribute lists default to empty when their key is absent.
    auto parse_optional_exprs = [&](const char* key) {
        return j.contains(key) ? parse_exprs(j.at(key)) : std::vector<symbolic::Expression>{};
    };

    auto mode = PoolingNode::string_to_mode(j.at("mode").get<std::string>());

    std::vector<symbolic::Expression> shape = parse_optional_exprs("shape");
    std::vector<symbolic::Expression> kernel_shape = parse_exprs(j.at("kernel_shape"));
    std::vector<symbolic::Expression> strides = parse_optional_exprs("strides");
    std::vector<symbolic::Expression> pads = parse_optional_exprs("pads");
    std::vector<symbolic::Expression> dilations = parse_optional_exprs("dilations");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    return builder
        .add_library_node<PoolingNode>(parent, debug_info, mode, shape, kernel_shape, strides, pads, dilations);
}
×
572

573
} // namespace tensor
574
} // namespace math
575
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc