• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23047631600

13 Mar 2026 10:54AM UTC coverage: 62.742%. First build
23047631600

Pull #582

github

web-flow
Merge 5ef470344 into 16e11e295
Pull Request #582: [MLIR] Pooling layer Support

2 of 466 new or added lines in 3 files covered. (0.43%)

24711 of 39385 relevant lines covered (62.74%)

364.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/data_flow/library_nodes/math/tensor/pooling_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/pooling_node.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/analysis/scope_analysis.h"
5
#include "sdfg/builder/structured_sdfg_builder.h"
6
#include "sdfg/data_flow/library_nodes/math/cmath/cmath_node.h"
7
#include "sdfg/data_flow/library_nodes/math/tensor/tensor_node.h"
8
#include "sdfg/types/type.h"
9

10
namespace sdfg {
11
namespace math {
12
namespace tensor {
13

14
PoolingNode::PoolingNode(
15
    size_t element_id,
16
    const DebugInfo& debug_info,
17
    const graph::Vertex vertex,
18
    data_flow::DataFlowGraph& parent,
19
    PoolingMode mode,
20
    const std::vector<symbolic::Expression>& shape,
21
    const std::vector<symbolic::Expression>& kernel_shape,
22
    const std::vector<symbolic::Expression>& strides,
23
    const std::vector<symbolic::Expression>& pads,
24
    const std::vector<symbolic::Expression>& dilations
25
)
NEW
26
    : TensorNode(
×
NEW
27
          element_id, debug_info, vertex, parent, LibraryNodeType_Pooling, {"Y"}, {"X"}, data_flow::ImplementationType_NONE
×
NEW
28
      ),
×
NEW
29
      mode_(mode), shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations) {}
×
30

NEW
31
void PoolingNode::validate(const Function& function) const {
×
NEW
32
    TensorNode::validate(function);
×
33

NEW
34
    if (kernel_shape_.empty()) {
×
NEW
35
        throw InvalidSDFGException("PoolingNode kernel_shape cannot be empty");
×
NEW
36
    }
×
37

NEW
38
    size_t spatial_dims = kernel_shape_.size();
×
39

NEW
40
    if (!strides_.empty() && strides_.size() != spatial_dims) {
×
NEW
41
        throw InvalidSDFGException("PoolingNode strides must match kernel spatial dimensions");
×
NEW
42
    }
×
43

NEW
44
    if (!pads_.empty() && pads_.size() != 2 * spatial_dims) {
×
NEW
45
        throw InvalidSDFGException("PoolingNode pads must have 2 * spatial dimensions");
×
NEW
46
    }
×
47

NEW
48
    if (!dilations_.empty() && dilations_.size() != spatial_dims) {
×
NEW
49
        throw InvalidSDFGException("PoolingNode dilations must match kernel spatial dimensions");
×
NEW
50
    }
×
NEW
51
}
×
52

NEW
53
bool PoolingNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
NEW
54
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
55

NEW
56
    auto& dataflow = this->get_parent();
×
NEW
57
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
NEW
58
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
NEW
59
    int index = parent.index(block);
×
NEW
60
    auto& transition = parent.at(index).second;
×
61

NEW
62
    auto primitive_type = this->primitive_type(dataflow);
×
NEW
63
    types::Scalar scalar_type(primitive_type);
×
64

NEW
65
    auto in_edges = dataflow.in_edges(*this);
×
NEW
66
    data_flow::Memlet* x_edge = nullptr;
×
NEW
67
    auto in_edges_it = in_edges.begin();
×
NEW
68
    while (in_edges_it != in_edges.end()) {
×
NEW
69
        auto& edge = *in_edges_it;
×
NEW
70
        if (edge.dst_conn() == "X") {
×
NEW
71
            x_edge = &edge;
×
NEW
72
        }
×
NEW
73
        ++in_edges_it;
×
NEW
74
    }
×
NEW
75
    if (!x_edge) {
×
NEW
76
        return false;
×
NEW
77
    }
×
78

NEW
79
    auto& y_edge = *dataflow.out_edges(*this).begin();
×
80

NEW
81
    auto* x_node = static_cast<data_flow::AccessNode*>(&x_edge->src());
×
NEW
82
    auto* y_node = static_cast<data_flow::AccessNode*>(&y_edge.dst());
×
83

NEW
84
    if (!x_node || dataflow.in_degree(*x_node) != 0 || !y_node || dataflow.out_degree(*y_node) != 0) {
×
NEW
85
        return false;
×
NEW
86
    }
×
87

NEW
88
    size_t spatial_dims = kernel_shape_.size();
×
NEW
89
    if (spatial_dims == 0) {
×
NEW
90
        return false;
×
NEW
91
    }
×
92

93
    // Get strides (default to 1)
NEW
94
    std::vector<symbolic::Expression> strides_vec;
×
NEW
95
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
96
        if (i < strides_.size()) {
×
NEW
97
            strides_vec.push_back(strides_[i]);
×
NEW
98
        } else {
×
NEW
99
            strides_vec.push_back(symbolic::one());
×
NEW
100
        }
×
NEW
101
    }
×
102

103
    // Get padding (default to 0)
NEW
104
    std::vector<symbolic::Expression> pads_begin_vec, pads_end_vec;
×
NEW
105
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
106
        if (i < pads_.size()) {
×
NEW
107
            pads_begin_vec.push_back(pads_[i]);
×
NEW
108
        } else {
×
NEW
109
            pads_begin_vec.push_back(symbolic::zero());
×
NEW
110
        }
×
NEW
111
        if (spatial_dims + i < pads_.size()) {
×
NEW
112
            pads_end_vec.push_back(pads_[spatial_dims + i]);
×
NEW
113
        } else {
×
NEW
114
            pads_end_vec.push_back(symbolic::zero());
×
NEW
115
        }
×
NEW
116
    }
×
117

118
    // Get dilations (default to 1)
NEW
119
    std::vector<symbolic::Expression> dilations_vec;
×
NEW
120
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
121
        if (i < dilations_.size()) {
×
NEW
122
            dilations_vec.push_back(dilations_[i]);
×
NEW
123
        } else {
×
NEW
124
            dilations_vec.push_back(symbolic::one());
×
NEW
125
        }
×
NEW
126
    }
×
127

NEW
128
    auto& X_var = x_node->data();
×
NEW
129
    auto& Y_var = y_node->data();
×
130

131
    // Input shape: [N, C, D0, D1, ..., Dn]
NEW
132
    symbolic::Expression N = shape_[0];
×
NEW
133
    symbolic::Expression C = shape_[1];
×
NEW
134
    std::vector<symbolic::Expression> input_spatial_dims;
×
NEW
135
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
136
        input_spatial_dims.push_back(shape_[2 + i]);
×
NEW
137
    }
×
138

139
    // Output spatial dimensions
NEW
140
    std::vector<symbolic::Expression> output_spatial_dims;
×
NEW
141
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
142
        auto d_in = input_spatial_dims[i];
×
NEW
143
        auto pad = symbolic::add(pads_begin_vec[i], pads_end_vec[i]);
×
NEW
144
        auto dk = symbolic::mul(dilations_vec[i], symbolic::sub(kernel_shape_[i], symbolic::one()));
×
NEW
145
        auto num = symbolic::sub(symbolic::add(d_in, pad), symbolic::add(dk, symbolic::one()));
×
NEW
146
        auto d_out = symbolic::add(symbolic::div(num, strides_vec[i]), symbolic::one());
×
NEW
147
        output_spatial_dims.push_back(d_out);
×
NEW
148
    }
×
149

NEW
150
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
×
151

NEW
152
    structured_control_flow::Sequence* current_scope = &new_sequence;
×
NEW
153
    std::vector<symbolic::Expression> output_indices;
×
NEW
154
    std::vector<symbolic::Expression> output_spatial_vars;
×
155

156
    // Map over batch
NEW
157
    std::string n_str = builder.find_new_name("n");
×
NEW
158
    builder.add_container(n_str, types::Scalar(types::PrimitiveType::UInt64));
×
NEW
159
    auto n_var = symbolic::symbol(n_str);
×
NEW
160
    auto& map_n = builder.add_map(
×
NEW
161
        *current_scope,
×
NEW
162
        n_var,
×
NEW
163
        symbolic::Lt(n_var, N),
×
NEW
164
        symbolic::zero(),
×
NEW
165
        symbolic::add(n_var, symbolic::one()),
×
NEW
166
        structured_control_flow::ScheduleType_Sequential::create(),
×
NEW
167
        {},
×
NEW
168
        block.debug_info()
×
NEW
169
    );
×
NEW
170
    current_scope = &map_n.root();
×
NEW
171
    output_indices.push_back(n_var);
×
172

173
    // Map over channel
NEW
174
    std::string c_str = builder.find_new_name("c");
×
NEW
175
    builder.add_container(c_str, types::Scalar(types::PrimitiveType::UInt64));
×
NEW
176
    auto c_var = symbolic::symbol(c_str);
×
NEW
177
    auto& map_c = builder.add_map(
×
NEW
178
        *current_scope,
×
NEW
179
        c_var,
×
NEW
180
        symbolic::Lt(c_var, C),
×
NEW
181
        symbolic::zero(),
×
NEW
182
        symbolic::add(c_var, symbolic::one()),
×
NEW
183
        structured_control_flow::ScheduleType_Sequential::create(),
×
NEW
184
        {},
×
NEW
185
        block.debug_info()
×
NEW
186
    );
×
NEW
187
    current_scope = &map_c.root();
×
NEW
188
    output_indices.push_back(c_var);
×
189

190
    // Map over each output spatial dimension
NEW
191
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
192
        std::string od_str = builder.find_new_name("od" + std::to_string(i));
×
NEW
193
        builder.add_container(od_str, types::Scalar(types::PrimitiveType::UInt64));
×
NEW
194
        auto od_var = symbolic::symbol(od_str);
×
NEW
195
        auto& map_od = builder.add_map(
×
NEW
196
            *current_scope,
×
NEW
197
            od_var,
×
NEW
198
            symbolic::Lt(od_var, output_spatial_dims[i]),
×
NEW
199
            symbolic::zero(),
×
NEW
200
            symbolic::add(od_var, symbolic::one()),
×
NEW
201
            structured_control_flow::ScheduleType_Sequential::create(),
×
NEW
202
            {},
×
NEW
203
            block.debug_info()
×
NEW
204
        );
×
NEW
205
        current_scope = &map_od.root();
×
NEW
206
        output_indices.push_back(od_var);
×
NEW
207
        output_spatial_vars.push_back(od_var);
×
NEW
208
    }
×
209

210
    // Create accumulator
NEW
211
    std::string accum_var = builder.find_new_name("_pool_accum");
×
NEW
212
    builder.add_container(accum_var, scalar_type);
×
213

214
    // Initialize accumulator
NEW
215
    std::string init_value;
×
NEW
216
    if (mode_ == PoolingMode::Max) {
×
217
        // Use -INFINITY for float, type-min for integers
NEW
218
        if (types::is_integer(primitive_type)) {
×
NEW
219
            switch (primitive_type) {
×
NEW
220
                case types::PrimitiveType::Int8:
×
NEW
221
                    init_value = "INT8_MIN";
×
NEW
222
                    break;
×
NEW
223
                case types::PrimitiveType::Int16:
×
NEW
224
                    init_value = "INT16_MIN";
×
NEW
225
                    break;
×
NEW
226
                case types::PrimitiveType::Int32:
×
NEW
227
                    init_value = "INT32_MIN";
×
NEW
228
                    break;
×
NEW
229
                case types::PrimitiveType::Int64:
×
NEW
230
                    init_value = "INT64_MIN";
×
NEW
231
                    break;
×
NEW
232
                default:
×
NEW
233
                    init_value = "0";
×
NEW
234
                    break;
×
NEW
235
            }
×
NEW
236
        } else {
×
NEW
237
            init_value = "-INFINITY";
×
NEW
238
        }
×
NEW
239
    } else {
×
240
        // Sum / Avg: init to 0
NEW
241
        init_value = types::is_integer(primitive_type) ? "0" : "0.0";
×
NEW
242
    }
×
243

NEW
244
    auto& init_block = builder.add_block(*current_scope, {}, block.debug_info());
×
NEW
245
    auto& accum_init = builder.add_access(init_block, accum_var, block.debug_info());
×
NEW
246
    auto& zero_const = builder.add_constant(init_block, init_value, scalar_type, block.debug_info());
×
NEW
247
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
×
NEW
248
    builder.add_computational_memlet(init_block, zero_const, init_tasklet, "_in", {}, scalar_type, block.debug_info());
×
NEW
249
    builder.add_computational_memlet(init_block, init_tasklet, "_out", accum_init, {}, scalar_type, block.debug_info());
×
250

251
    // For loops over kernel spatial dimensions
NEW
252
    auto* loop_scope = current_scope;
×
NEW
253
    std::vector<symbolic::Expression> kernel_vars;
×
NEW
254
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
255
        std::string k_str = builder.find_new_name("k" + std::to_string(i));
×
NEW
256
        builder.add_container(k_str, types::Scalar(types::PrimitiveType::UInt64));
×
NEW
257
        auto k_var = symbolic::symbol(k_str);
×
NEW
258
        auto& for_k = builder.add_for(
×
NEW
259
            *loop_scope,
×
NEW
260
            k_var,
×
NEW
261
            symbolic::Lt(k_var, kernel_shape_[i]),
×
NEW
262
            symbolic::zero(),
×
NEW
263
            symbolic::add(k_var, symbolic::one()),
×
NEW
264
            {},
×
NEW
265
            block.debug_info()
×
NEW
266
        );
×
NEW
267
        loop_scope = &for_k.root();
×
NEW
268
        kernel_vars.push_back(k_var);
×
NEW
269
    }
×
270

271
    // Compute input spatial indices
NEW
272
    std::vector<symbolic::Expression> input_spatial_indices;
×
NEW
273
    for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
274
        auto k_dilated = symbolic::mul(kernel_vars[i], dilations_vec[i]);
×
NEW
275
        auto input_idx = symbolic::
×
NEW
276
            add(symbolic::sub(symbolic::mul(output_spatial_vars[i], strides_vec[i]), pads_begin_vec[i]), k_dilated);
×
NEW
277
        input_spatial_indices.push_back(input_idx);
×
NEW
278
    }
×
279

280
    // Build X indices: [n, c, input_spatial...]
NEW
281
    std::vector<symbolic::Expression> x_indices_vec = {n_var, c_var};
×
NEW
282
    x_indices_vec.insert(x_indices_vec.end(), input_spatial_indices.begin(), input_spatial_indices.end());
×
NEW
283
    data_flow::Subset x_subset(x_indices_vec);
×
284

285
    // Computation block: accumulate
NEW
286
    auto& comp_block = builder.add_block(*loop_scope, {}, block.debug_info());
×
NEW
287
    auto& x_access = builder.add_access(comp_block, X_var, x_node->debug_info());
×
NEW
288
    auto& accum_read = builder.add_access(comp_block, accum_var, block.debug_info());
×
NEW
289
    auto& accum_write = builder.add_access(comp_block, accum_var, block.debug_info());
×
290

NEW
291
    if (mode_ == PoolingMode::Max) {
×
NEW
292
        bool is_int = types::is_integer(primitive_type);
×
NEW
293
        if (is_int) {
×
NEW
294
            auto tasklet_code = TensorNode::get_integer_minmax_tasklet(primitive_type, true);
×
NEW
295
            auto& tasklet = builder.add_tasklet(comp_block, tasklet_code, "_out", {"_in1", "_in2"}, block.debug_info());
×
NEW
296
            builder.add_computational_memlet(
×
NEW
297
                comp_block, x_access, tasklet, "_in1", x_subset, x_edge->base_type(), block.debug_info()
×
NEW
298
            );
×
NEW
299
            builder
×
NEW
300
                .add_computational_memlet(comp_block, accum_read, tasklet, "_in2", {}, scalar_type, block.debug_info());
×
NEW
301
            builder
×
NEW
302
                .add_computational_memlet(comp_block, tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());
×
NEW
303
        } else {
×
NEW
304
            auto& libnode = builder.add_library_node<
×
NEW
305
                math::cmath::CMathNode>(comp_block, block.debug_info(), cmath::CMathFunction::fmax, primitive_type);
×
NEW
306
            builder.add_computational_memlet(
×
NEW
307
                comp_block, x_access, libnode, "_in1", x_subset, x_edge->base_type(), block.debug_info()
×
NEW
308
            );
×
NEW
309
            builder
×
NEW
310
                .add_computational_memlet(comp_block, accum_read, libnode, "_in2", {}, scalar_type, block.debug_info());
×
NEW
311
            builder
×
NEW
312
                .add_computational_memlet(comp_block, libnode, "_out", accum_write, {}, scalar_type, block.debug_info());
×
NEW
313
        }
×
NEW
314
    } else {
×
315
        // Sum or Avg: accumulate with addition
NEW
316
        bool is_int = types::is_integer(primitive_type);
×
NEW
317
        data_flow::TaskletCode opcode = is_int ? data_flow::TaskletCode::int_add : data_flow::TaskletCode::fp_add;
×
NEW
318
        auto& tasklet = builder.add_tasklet(comp_block, opcode, "_out", {"_in1", "_in2"}, block.debug_info());
×
NEW
319
        builder.add_computational_memlet(
×
NEW
320
            comp_block, x_access, tasklet, "_in1", x_subset, x_edge->base_type(), block.debug_info()
×
NEW
321
        );
×
NEW
322
        builder.add_computational_memlet(comp_block, accum_read, tasklet, "_in2", {}, scalar_type, block.debug_info());
×
NEW
323
        builder.add_computational_memlet(comp_block, tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());
×
NEW
324
    }
×
325

326
    // After kernel loops: write result to output
NEW
327
    data_flow::Subset y_subset(output_indices);
×
328

NEW
329
    auto& output_block = builder.add_block(*current_scope, {}, block.debug_info());
×
NEW
330
    auto& accum_final = builder.add_access(output_block, accum_var, block.debug_info());
×
NEW
331
    auto& y_access = builder.add_access(output_block, Y_var, y_node->debug_info());
×
332

NEW
333
    if (mode_ == PoolingMode::Avg) {
×
334
        // Divide by window size: product of kernel_shape dimensions
335
        // Create a temporary for the divisor
NEW
336
        std::string divisor_var = builder.find_new_name("_pool_divisor");
×
NEW
337
        builder.add_container(divisor_var, scalar_type);
×
338

339
        // Compute window size as product of kernel dimensions
NEW
340
        symbolic::Expression window_size = kernel_shape_[0];
×
NEW
341
        for (size_t i = 1; i < spatial_dims; ++i) {
×
NEW
342
            window_size = symbolic::mul(window_size, kernel_shape_[i]);
×
NEW
343
        }
×
344

NEW
345
        auto& divisor_const =
×
NEW
346
            builder.add_constant(output_block, window_size->__str__(), scalar_type, block.debug_info());
×
NEW
347
        auto& divisor_access = builder.add_access(output_block, divisor_var, block.debug_info());
×
NEW
348
        auto& divisor_assign =
×
NEW
349
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
×
NEW
350
        builder.add_computational_memlet(
×
NEW
351
            output_block, divisor_const, divisor_assign, "_in", {}, scalar_type, block.debug_info()
×
NEW
352
        );
×
NEW
353
        builder.add_computational_memlet(
×
NEW
354
            output_block, divisor_assign, "_out", divisor_access, {}, scalar_type, block.debug_info()
×
NEW
355
        );
×
356

NEW
357
        bool is_int = types::is_integer(primitive_type);
×
NEW
358
        data_flow::TaskletCode div_opcode = is_int ? data_flow::TaskletCode::int_sdiv : data_flow::TaskletCode::fp_div;
×
NEW
359
        auto& div_tasklet = builder.add_tasklet(output_block, div_opcode, "_out", {"_in1", "_in2"}, block.debug_info());
×
NEW
360
        builder
×
NEW
361
            .add_computational_memlet(output_block, accum_final, div_tasklet, "_in1", {}, scalar_type, block.debug_info());
×
NEW
362
        builder.add_computational_memlet(
×
NEW
363
            output_block, divisor_access, div_tasklet, "_in2", {}, scalar_type, block.debug_info()
×
NEW
364
        );
×
NEW
365
        builder.add_computational_memlet(
×
NEW
366
            output_block, div_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
×
NEW
367
        );
×
NEW
368
    } else {
×
369
        // Max or Sum: just assign
NEW
370
        auto& assign_tasklet =
×
NEW
371
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
×
NEW
372
        builder.add_computational_memlet(
×
NEW
373
            output_block, accum_final, assign_tasklet, "_in", {}, scalar_type, block.debug_info()
×
NEW
374
        );
×
NEW
375
        builder.add_computational_memlet(
×
NEW
376
            output_block, assign_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
×
NEW
377
        );
×
NEW
378
    }
×
379

380
    // Clean up original block
NEW
381
    builder.remove_memlet(block, *x_edge);
×
NEW
382
    builder.remove_memlet(block, y_edge);
×
NEW
383
    builder.remove_node(block, *x_node);
×
NEW
384
    builder.remove_node(block, *y_node);
×
NEW
385
    builder.remove_node(block, *this);
×
NEW
386
    builder.remove_child(parent, index + 1);
×
387

NEW
388
    return true;
×
NEW
389
}
×
390

NEW
391
/// Collects every atom referenced by the node's shape and pooling attributes
/// (shape, kernel_shape, strides, pads, dilations, in that order).
symbolic::SymbolSet PoolingNode::symbols() const {
    symbolic::SymbolSet collected;
    const std::vector<symbolic::Expression>* attribute_lists[] = {
        &shape_, &kernel_shape_, &strides_, &pads_, &dilations_
    };
    for (const auto* exprs : attribute_lists) {
        for (const auto& expr : *exprs) {
            for (auto& atom : symbolic::atoms(expr)) {
                collected.insert(atom);
            }
        }
    }
    return collected;
}
×
420

NEW
421
/// Substitutes old_expression with new_expression, in place, in every shape / attribute expression.
void PoolingNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    std::vector<symbolic::Expression>* attribute_lists[] = {&shape_, &kernel_shape_, &strides_, &pads_, &dilations_};
    for (auto* exprs : attribute_lists) {
        for (auto& expr : *exprs) {
            expr = symbolic::subs(expr, old_expression, new_expression);
        }
    }
}
×
438

439
std::unique_ptr<data_flow::DataFlowNode> PoolingNode::
NEW
440
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
441
    return std::unique_ptr<data_flow::DataFlowNode>(new PoolingNode(
×
NEW
442
        element_id, this->debug_info(), vertex, parent, mode_, shape_, kernel_shape_, strides_, pads_, dilations_
×
NEW
443
    ));
×
NEW
444
}
×
445

NEW
446
/// Maps a PoolingMode to its serialized name ("max", "sum", "avg"); any other value yields "unknown".
std::string PoolingNode::mode_to_string(PoolingMode mode) {
    const char* name = "unknown";
    switch (mode) {
        case PoolingMode::Max:
            name = "max";
            break;
        case PoolingMode::Sum:
            name = "sum";
            break;
        case PoolingMode::Avg:
            name = "avg";
            break;
    }
    return name;
}
×
457

NEW
458
/// Inverse of mode_to_string: parses "max" / "sum" / "avg" (exact, case-sensitive match).
///
/// @throws InvalidSDFGException for any other string.
PoolingMode PoolingNode::string_to_mode(const std::string& str) {
    for (PoolingMode candidate : {PoolingMode::Max, PoolingMode::Sum, PoolingMode::Avg}) {
        if (mode_to_string(candidate) == str) {
            return candidate;
        }
    }
    throw InvalidSDFGException("Unknown pooling mode: " + str);
}
×
464

NEW
465
std::string PoolingNode::toStr() const {
×
NEW
466
    std::string result = "Pooling(mode=" + mode_to_string(mode_) + ", shape=[";
×
NEW
467
    for (size_t i = 0; i < shape_.size(); ++i) {
×
NEW
468
        if (i > 0) result += ", ";
×
NEW
469
        result += shape_[i]->__str__();
×
NEW
470
    }
×
NEW
471
    result += "], kernel_shape=[";
×
NEW
472
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
NEW
473
        if (i > 0) result += ", ";
×
NEW
474
        result += kernel_shape_[i]->__str__();
×
NEW
475
    }
×
NEW
476
    result += "], strides=[";
×
NEW
477
    for (size_t i = 0; i < strides_.size(); ++i) {
×
NEW
478
        if (i > 0) result += ", ";
×
NEW
479
        result += strides_[i]->__str__();
×
NEW
480
    }
×
NEW
481
    result += "])";
×
NEW
482
    return result;
×
NEW
483
}
×
484

NEW
485
/// Serializes a PoolingNode to JSON: "code", "mode" and the expression arrays "shape",
/// "kernel_shape", "strides", "pads" and "dilations" (each always present, possibly empty).
nlohmann::json PoolingNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const PoolingNode&>(library_node);
    serializer::JSONSerializer serializer;

    // Converts a list of symbolic expressions into a JSON array via the shared serializer.
    auto to_json_array = [&serializer](const std::vector<symbolic::Expression>& exprs) {
        nlohmann::json arr = nlohmann::json::array();
        for (const auto& expr : exprs) {
            arr.push_back(serializer.expression(expr));
        }
        return arr;
    };

    nlohmann::json j;
    j["code"] = node.code().value();
    j["mode"] = PoolingNode::mode_to_string(node.mode());
    j["shape"] = to_json_array(node.shape());
    j["kernel_shape"] = to_json_array(node.kernel_shape());
    j["strides"] = to_json_array(node.strides());
    j["pads"] = to_json_array(node.pads());
    j["dilations"] = to_json_array(node.dilations());
    return j;
}
×
521

522
/// Rebuilds a PoolingNode from JSON produced by serialize() and adds it to `parent`.
///
/// Required keys (checked by assert only): "element_id", "code", "debug_info", "mode", "kernel_shape".
/// Optional keys "shape", "strides", "pads" and "dilations" default to empty lists.
data_flow::LibraryNode& PoolingNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("mode"));
    assert(j.contains("kernel_shape"));

    // Parses a JSON array of expression strings.
    auto parse_list = [](const nlohmann::json& arr) {
        std::vector<symbolic::Expression> exprs;
        for (const auto& item : arr) {
            exprs.push_back(symbolic::parse(item.get<std::string>()));
        }
        return exprs;
    };
    // Parses an optional expression array; an absent key yields an empty list.
    auto parse_optional = [&](const char* key) {
        return j.contains(key) ? parse_list(j[key]) : std::vector<symbolic::Expression>{};
    };

    auto mode = PoolingNode::string_to_mode(j["mode"].get<std::string>());
    auto shape = parse_optional("shape");
    auto kernel_shape = parse_list(j["kernel_shape"]);
    auto strides = parse_optional("strides");
    auto pads = parse_optional("pads");
    auto dilations = parse_optional("dilations");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder
        .add_library_node<PoolingNode>(parent, debug_info, mode, shape, kernel_shape, strides, pads, dilations);
}
×
572

573
} // namespace tensor
574
} // namespace math
575
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc