• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23639449060

27 Mar 2026 09:19AM UTC coverage: 64.469% (-0.4%) from 64.845%
23639449060

Pull #613

github

web-flow
Merge 2dc01bd50 into c252af595
Pull Request #613: Use im2row/im2col as expansion of convolution node

370 of 729 new or added lines in 1 file covered. (50.75%)

13 existing lines in 2 files now uncovered.

27126 of 42076 relevant lines covered (64.47%)

403.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

45.06
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4
#include <sstream>
5

6
#include "sdfg/analysis/analysis.h"
7
#include "sdfg/builder/structured_sdfg_builder.h"
8
#include "sdfg/data_flow/access_node.h"
9
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
10
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
11
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
12
#include "sdfg/data_flow/library_nodes/stdlib/memset.h"
13
#include "sdfg/data_flow/memlet.h"
14
#include "sdfg/data_flow/tasklet.h"
15
#include "sdfg/exceptions.h"
16
#include "sdfg/structured_control_flow/block.h"
17
#include "sdfg/structured_control_flow/map.h"
18
#include "sdfg/structured_control_flow/sequence.h"
19
#include "sdfg/symbolic/symbolic.h"
20
#include "sdfg/types/pointer.h"
21
#include "sdfg/types/scalar.h"
22
#include "sdfg/types/tensor.h"
23
#include "sdfg/types/type.h"
24

25
#include "sdfg/analysis/scope_analysis.h"
26
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
27
#include "symengine/integer.h"
28
#include "symengine/symengine_rcp.h"
29

30
namespace sdfg {
31
namespace math {
32
namespace tensor {
33

34
ConvNode::ConvNode(
35
    size_t element_id,
36
    const DebugInfo& debug_info,
37
    const graph::Vertex vertex,
38
    data_flow::DataFlowGraph& parent,
39
    const std::vector<symbolic::Expression>& shape,
40
    const std::vector<symbolic::Expression>& kernel_shape,
41
    const std::vector<symbolic::Expression>& strides,
42
    const std::vector<symbolic::Expression>& pads,
43
    const std::vector<symbolic::Expression>& dilations,
44
    symbolic::Expression output_channels,
45
    symbolic::Expression group
46
)
47
    : TensorNode(
22✔
48
          element_id,
22✔
49
          debug_info,
22✔
50
          vertex,
22✔
51
          parent,
22✔
52
          LibraryNodeType_Conv,
22✔
53
          {"Y"},
22✔
54
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
22✔
55
          data_flow::ImplementationType_NONE
22✔
56
      ),
22✔
57
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
22✔
58
      output_channels_(output_channels), group_(group) {}
22✔
59

60
void ConvNode::validate(const Function& function) const {
17✔
61
    TensorNode::validate(function);
17✔
62

63
    auto& graph = this->get_parent();
17✔
64

65
    // Custom validation for ConvNode that handles optional bias input
66
    // We expect X, W as required inputs and optionally B (bias)
67

68
    // Collect all input edges by connector name
69
    std::map<std::string, const data_flow::Memlet*> input_edges;
17✔
70
    for (auto& iedge : graph.in_edges(*this)) {
34✔
71
        input_edges[iedge.dst_conn()] = &iedge;
34✔
72
    }
34✔
73

74
    // Check that required inputs X and W are present
75
    if (input_edges.find("X") == input_edges.end()) {
17✔
76
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
77
    }
×
78
    if (input_edges.find("W") == input_edges.end()) {
17✔
79
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
80
    }
×
81

82
    // Validate that parameters are not empty
83
    if (shape_.empty()) {
17✔
NEW
84
        throw InvalidSDFGException("ConvNode shape cannot be empty");
×
NEW
85
    }
×
86
    if (kernel_shape_.empty()) {
17✔
87
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
88
    }
×
89
    if (strides_.empty()) {
17✔
NEW
90
        throw InvalidSDFGException("ConvNode strides cannot be empty");
×
NEW
91
    }
×
92
    if (pads_.empty()) {
17✔
NEW
93
        throw InvalidSDFGException("ConvNode pads cannot be empty");
×
NEW
94
    }
×
95
    if (dilations_.empty()) {
17✔
NEW
96
        throw InvalidSDFGException("ConvNode dilations cannot be empty");
×
NEW
97
    }
×
98

99
    // Validate consistent dimensions
100
    size_t spatial_dims = kernel_shape_.size();
17✔
101

102
    if (shape_.size() != spatial_dims + 2) {
17✔
NEW
103
        throw InvalidSDFGException("ConvNode shape must match kernel spatial dimensions + 2");
×
NEW
104
    }
×
105

106
    if (strides_.size() != spatial_dims) {
17✔
107
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
108
    }
1✔
109

110
    if (pads_.size() != 2 * spatial_dims) {
16✔
111
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
112
    }
1✔
113

114
    if (dilations_.size() != spatial_dims) {
15✔
115
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
116
    }
×
117

118
    // Validate groups
119
    if (SymEngine::is_a<SymEngine::Integer>(*this->group_)) {
15✔
120
        auto group_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->group_)->as_int();
15✔
121
        if (SymEngine::is_a<SymEngine::Integer>(*this->shape_[1])) {
15✔
122
            auto input_channels_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->shape_[1])->as_int();
15✔
123
            if (input_channels_int % group_int != 0) {
15✔
NEW
124
                throw InvalidSDFGException("ConvNode input channels must be divisible by groups");
×
NEW
125
            }
×
126
        }
15✔
127
        if (SymEngine::is_a<SymEngine::Integer>(*this->output_channels_)) {
15✔
128
            auto output_channels_int =
15✔
129
                SymEngine::rcp_static_cast<const SymEngine::Integer>(this->output_channels_)->as_int();
15✔
130
            if (output_channels_int % group_int != 0) {
15✔
NEW
131
                throw InvalidSDFGException("ConvNode output channels must be divisible by groups");
×
NEW
132
            }
×
133
        }
15✔
134
    }
15✔
135
}
15✔
136

137
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
138
    // Validate nodes are standalone in the data flow graph
139
    auto& dfg = this->get_parent();
5✔
140
    if ((dfg.nodes().size() != 4 || dfg.edges().size() != 3) && (dfg.nodes().size() != 5 || dfg.edges().size() != 4)) {
5✔
NEW
141
        return false;
×
UNCOV
142
    }
×
143

144
    // Get edges
145
    auto iedges = dfg.in_edges_by_connector(*this);
5✔
146
    auto oedges = dfg.out_edges_by_connector(*this);
5✔
147
    if (iedges.size() != 3 || oedges.size() != 1) {
5✔
NEW
148
        return false;
×
NEW
149
    }
×
150
    auto* iedge_X = iedges.at(0);
5✔
151
    auto* iedge_W = iedges.at(1);
5✔
152
    auto* iedge_B = iedges.at(2);
5✔
153
    auto* oedge_Y = oedges.at(0);
5✔
154
    if (!iedge_X || !iedge_W || !oedge_Y) {
5✔
NEW
155
        return false;
×
NEW
156
    }
×
157
    bool has_bias = iedge_B != nullptr;
5✔
158

159
    // Get access nodes
160
    auto* access_X = dynamic_cast<data_flow::AccessNode*>(&iedge_X->src());
5✔
161
    auto* access_W = dynamic_cast<data_flow::AccessNode*>(&iedge_W->src());
5✔
162
    auto* access_B = (has_bias ? dynamic_cast<data_flow::AccessNode*>(&iedge_B->src()) : nullptr);
5✔
163
    auto* access_Y = dynamic_cast<data_flow::AccessNode*>(&oedge_Y->dst());
5✔
164
    if (!access_X || !access_W || (has_bias && !access_B) || !access_Y) {
5✔
165
        return false;
×
166
    }
×
167

168
    // Get block & its parent
169
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
5✔
170
    if (!block) {
5✔
NEW
171
        return false;
×
NEW
172
    }
×
173
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
174
    auto* block_parent = dynamic_cast<structured_control_flow::Sequence*>(scope_analysis.parent_scope(block));
5✔
175
    if (!block_parent) {
5✔
NEW
176
        return false;
×
NEW
177
    }
×
178
    size_t block_index = block_parent->index(*block);
5✔
179
    if (block_index >= block_parent->size()) {
5✔
180
        return false;
×
181
    }
×
182

183
    // Determine BLAS precision
184
    blas::BLAS_Precision precision;
5✔
185
    types::Scalar base_type(this->primitive_type(dfg));
5✔
186
    switch (base_type.primitive_type()) {
5✔
NEW
187
        case types::PrimitiveType::Half:
×
NEW
188
            precision = blas::BLAS_Precision::h;
×
NEW
189
            break;
×
190
        case types::PrimitiveType::Float:
5✔
191
            precision = blas::BLAS_Precision::s;
5✔
192
            break;
5✔
NEW
193
        case types::PrimitiveType::Double:
×
NEW
194
            precision = blas::BLAS_Precision::d;
×
NEW
195
            break;
×
NEW
196
        default:
×
NEW
197
            return false;
×
198
    }
5✔
199

200
    // Create new sequence for expansion
201
    auto& new_sequence = builder.add_sequence_before(
5✔
202
        *block_parent, *block, block_parent->at(block_index).second.assignments(), block->debug_info()
5✔
203
    );
5✔
204

205
    // Dimensions, i.e., 1D, 2D, 3D, ...
206
    size_t dims = this->kernel_shape_.size();
5✔
207
    symbolic::MultiExpression out_shape;
5✔
208
    out_shape.reserve(dims);
5✔
209
    // out_shape[i] = (shape[i + 2] + pads[i] + pads[dims + i] - dilations[i] * (kernel_shape[i] - 1) - 1)
210
    //                 / strides[i] + 1
211
    for (size_t i = 0; i < dims; i++) {
15✔
212
        out_shape.push_back(symbolic::add(
10✔
213
            symbolic::div(
10✔
214
                symbolic::sub(
10✔
215
                    symbolic::
10✔
216
                        sub(symbolic::add(this->shape_[i + 2], symbolic::add(this->pads_[i], this->pads_[dims + i])),
10✔
217
                            symbolic::mul(this->dilations_[i], symbolic::sub(this->kernel_shape_[i], symbolic::one()))),
10✔
218
                    symbolic::one()
10✔
219
                ),
10✔
220
                this->strides_[i]
10✔
221
            ),
10✔
222
            symbolic::one()
10✔
223
        ));
10✔
224
    }
10✔
225
    types::Scalar indvar_type(types::PrimitiveType::Int64);
5✔
226

227
    // If there are no groups (i.e., group == 1), then we can do im2row with one GEMM.
228
    // Else, we do naïve im2col with multiple GEMM's.
229
    if (symbolic::eq(this->group_, symbolic::one())) {
5✔
230
        /* ===== No groups ====================================================================== */
231

232
        // Add patches container with malloc
233
        symbolic::Expression patches_size = symbolic::mul(this->shape_[0], this->shape_[1]);
5✔
234
        for (size_t i = 0; i < dims; i++) {
15✔
235
            patches_size = symbolic::mul(patches_size, symbolic::mul(this->kernel_shape_[i], out_shape[i]));
10✔
236
        }
10✔
237
        types::Pointer patches_type(base_type);
5✔
238
        auto patches_container = builder.find_new_name("_patches");
5✔
239
        builder.add_container(patches_container, patches_type);
5✔
240
        auto& patches_malloc_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
241
        {
5✔
242
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
5✔
243
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
5✔
244
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
5✔
245
            );
5✔
246
            builder.add_computational_memlet(
5✔
247
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
5✔
248
            );
5✔
249
        }
5✔
250

251
        // Memset patches to zero
252
        auto& patches_memset_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
253
        {
5✔
254
            auto& patches_access = builder.add_access(patches_memset_block, patches_container, this->debug_info());
5✔
255
            auto& libnode = builder.add_library_node<stdlib::MemsetNode>(
5✔
256
                patches_memset_block,
5✔
257
                this->debug_info(),
5✔
258
                symbolic::zero(),
5✔
259
                symbolic::mul(patches_size, symbolic::size_of_type(base_type))
5✔
260
            );
5✔
261
            builder.add_computational_memlet(
5✔
262
                patches_memset_block, libnode, "_ptr", patches_access, {}, patches_type, this->debug_info()
5✔
263
            );
5✔
264
        }
5✔
265

266
        // Add malloc for temporary GEMM output
267
        symbolic::Expression tmp_Y_size = symbolic::mul(this->output_channels_, this->shape_[0]);
5✔
268
        for (size_t i = 0; i < dims; i++) {
15✔
269
            tmp_Y_size = symbolic::mul(tmp_Y_size, out_shape[i]);
10✔
270
        }
10✔
271
        auto tmp_Y_container = builder.find_new_name("_tmp_Y");
5✔
272
        types::Scalar tmp_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
5✔
273
        types::Pointer tmp_Y_type(tmp_Y_base_type);
5✔
274
        builder.add_container(tmp_Y_container, tmp_Y_type);
5✔
275
        auto& tmp_Y_malloc_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
276
        {
5✔
277
            auto& tmp_Y_access = builder.add_access(tmp_Y_malloc_block, tmp_Y_container, this->debug_info());
5✔
278
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
5✔
279
                tmp_Y_malloc_block,
5✔
280
                this->debug_info(),
5✔
281
                symbolic::mul(tmp_Y_size, symbolic::size_of_type(tmp_Y_base_type))
5✔
282
            );
5✔
283
            builder.add_computational_memlet(
5✔
284
                tmp_Y_malloc_block, libnode, "_ret", tmp_Y_access, {}, tmp_Y_type, this->debug_info()
5✔
285
            );
5✔
286
        }
5✔
287

288
        // Add loop over batch size
289
        auto n_container = builder.find_new_name("_n");
5✔
290
        builder.add_container(n_container, indvar_type);
5✔
291
        auto n = symbolic::symbol(n_container);
5✔
292
        auto& loop_n = builder.add_map(
5✔
293
            new_sequence,
5✔
294
            n,
5✔
295
            symbolic::Lt(n, this->shape_[0]),
5✔
296
            symbolic::zero(),
5✔
297
            symbolic::add(n, symbolic::one()),
5✔
298
            ScheduleType_Sequential::create(),
5✔
299
            {},
5✔
300
            block->debug_info()
5✔
301
        );
5✔
302
        structured_control_flow::Sequence* current_seq = &loop_n.root();
5✔
303

304
        // Add loops over output dimensions
305
        symbolic::SymbolVec os;
5✔
306
        os.reserve(dims);
5✔
307
        for (size_t i = 0; i < dims; i++) {
15✔
308
            auto o_container = builder.find_new_name("_o");
10✔
309
            builder.add_container(o_container, indvar_type);
10✔
310
            auto o = symbolic::symbol(o_container);
10✔
311
            os.push_back(o);
10✔
312
            auto& loop_o = builder.add_map(
10✔
313
                *current_seq,
10✔
314
                o,
10✔
315
                symbolic::Lt(o, out_shape[i]),
10✔
316
                symbolic::zero(),
10✔
317
                symbolic::add(o, symbolic::one()),
10✔
318
                ScheduleType_Sequential::create(),
10✔
319
                {},
10✔
320
                block->debug_info()
10✔
321
            );
10✔
322
            current_seq = &loop_o.root();
10✔
323
        }
10✔
324

325
        // Add loop over channels
326
        auto c_container = builder.find_new_name("_c");
5✔
327
        builder.add_container(c_container, indvar_type);
5✔
328
        auto c = symbolic::symbol(c_container);
5✔
329
        auto& loop_c = builder.add_map(
5✔
330
            *current_seq,
5✔
331
            c,
5✔
332
            symbolic::Lt(c, this->shape_[1]),
5✔
333
            symbolic::zero(),
5✔
334
            symbolic::add(c, symbolic::one()),
5✔
335
            ScheduleType_Sequential::create(),
5✔
336
            {},
5✔
337
            block->debug_info()
5✔
338
        );
5✔
339
        current_seq = &loop_c.root();
5✔
340

341
        // Add loops over kernel shape
342
        symbolic::SymbolVec ks;
5✔
343
        ks.reserve(dims);
5✔
344
        symbolic::MultiExpression is;
5✔
345
        is.reserve(dims);
5✔
346
        for (size_t i = 0; i < dims; i++) {
15✔
347
            auto k_container = builder.find_new_name("_k");
10✔
348
            builder.add_container(k_container, indvar_type);
10✔
349
            auto k = symbolic::symbol(k_container);
10✔
350
            ks.push_back(k);
10✔
351
            auto i_expr = symbolic::
10✔
352
                add(symbolic::sub(symbolic::mul(os[i], this->strides_[i]), this->pads_[i]),
10✔
353
                    symbolic::mul(k, this->dilations_[i]));
10✔
354
            is.push_back(i_expr);
10✔
355
            auto& loop_k = builder.add_map(
10✔
356
                *current_seq,
10✔
357
                k,
10✔
358
                symbolic::And(symbolic::Lt(k, this->kernel_shape_[i]), symbolic::Lt(i_expr, this->shape_[i + 2])),
10✔
359
                symbolic::zero(),
10✔
360
                symbolic::add(k, symbolic::one()),
10✔
361
                ScheduleType_Sequential::create(),
10✔
362
                {},
10✔
363
                block->debug_info()
10✔
364
            );
10✔
365
            current_seq = &loop_k.root();
10✔
366
        }
10✔
367

368
        // Determine patches subset & tensor type
369
        data_flow::Subset patches_subset;
5✔
370
        patches_subset.push_back(n);
5✔
371
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
5✔
372
        patches_subset.push_back(c);
5✔
373
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
5✔
374
        symbolic::MultiExpression patches_shape;
5✔
375
        patches_shape.push_back(this->shape_[0]);
5✔
376
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
5✔
377
        patches_shape.push_back(this->shape_[1]);
5✔
378
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
5✔
379
        types::Tensor patches_tensor_type(base_type, patches_shape);
5✔
380

381
        // Determine subset for X
382
        data_flow::Subset subset_X;
5✔
383
        subset_X.push_back(n);
5✔
384
        subset_X.push_back(c);
5✔
385
        subset_X.insert(subset_X.end(), is.begin(), is.end());
5✔
386

387
        // Add copy from X to patches
388
        auto& true_block = builder.add_block(*current_seq, {}, block->debug_info());
5✔
389
        {
5✔
390
            auto& X_access = builder.add_access(true_block, access_X->data(), access_X->debug_info());
5✔
391
            auto& patches_access = builder.add_access(true_block, patches_container, this->debug_info());
5✔
392
            auto& tasklet =
5✔
393
                builder.add_tasklet(true_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
394
            builder.add_computational_memlet(
5✔
395
                true_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
5✔
396
            );
5✔
397
            builder.add_computational_memlet(
5✔
398
                true_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
5✔
399
            );
5✔
400
        }
5✔
401

402
        // Add GEMM node
403
        auto& gemm_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
404
        {
5✔
405
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
5✔
406
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
5✔
407
            auto& W_access = builder.add_access(gemm_block, access_W->data(), access_W->debug_info());
5✔
408
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
5✔
409
            auto& tmp_Y_access_in = builder.add_access(gemm_block, tmp_Y_container, access_Y->debug_info());
5✔
410
            auto& tmp_Y_access_out = builder.add_access(gemm_block, tmp_Y_container, access_Y->debug_info());
5✔
411
            symbolic::Expression gemm_m = this->output_channels_;
5✔
412
            symbolic::Expression gemm_n = this->shape_[0];
5✔
413
            symbolic::Expression gemm_k = this->shape_[1];
5✔
414
            for (size_t i = 0; i < dims; i++) {
15✔
415
                gemm_n = symbolic::mul(gemm_n, out_shape[i]);
10✔
416
                gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
10✔
417
            }
10✔
418
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
5✔
419
                gemm_block,
5✔
420
                this->debug_info(),
5✔
421
                blas::ImplementationType_BLAS,
5✔
422
                precision, // precision
5✔
423
                blas::BLAS_Layout::RowMajor, // layout
5✔
424
                blas::BLAS_Transpose::No, // transA
5✔
425
                blas::BLAS_Transpose::Trans, // transB
5✔
426
                gemm_m, // m
5✔
427
                gemm_n, // n
5✔
428
                gemm_k, // k
5✔
429
                gemm_k, // lda
5✔
430
                gemm_k, // ldb
5✔
431
                gemm_n // ldc
5✔
432
            );
5✔
433
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
5✔
434
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
5✔
435
            builder.add_computational_memlet(
5✔
436
                gemm_block,
5✔
437
                W_access,
5✔
438
                libnode,
5✔
439
                "__A",
5✔
440
                {},
5✔
441
                types::Pointer(types::Scalar(iedge_W->base_type().primitive_type())),
5✔
442
                iedge_W->debug_info()
5✔
443
            );
5✔
444
            builder.add_computational_memlet(
5✔
445
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
5✔
446
            );
5✔
447
            builder.add_computational_memlet(
5✔
448
                gemm_block, tmp_Y_access_in, libnode, "__C", {}, tmp_Y_type, oedge_Y->debug_info()
5✔
449
            );
5✔
450
            builder.add_computational_memlet(
5✔
451
                gemm_block, libnode, "__C", tmp_Y_access_out, {}, tmp_Y_type, oedge_Y->debug_info()
5✔
452
            );
5✔
453
        }
5✔
454

455
        // Add loop over batch size (again)
456
        auto& loop_n_2 = builder.add_map(
5✔
457
            new_sequence,
5✔
458
            n,
5✔
459
            symbolic::Lt(n, this->shape_[0]),
5✔
460
            symbolic::zero(),
5✔
461
            symbolic::add(n, symbolic::one()),
5✔
462
            ScheduleType_Sequential::create(),
5✔
463
            {},
5✔
464
            block->debug_info()
5✔
465
        );
5✔
466
        current_seq = &loop_n_2.root();
5✔
467

468
        // Add loop over output channels
469
        auto l_container = builder.find_new_name("_l");
5✔
470
        builder.add_container(l_container, indvar_type);
5✔
471
        auto l = symbolic::symbol(l_container);
5✔
472
        auto& loop_l = builder.add_map(
5✔
473
            *current_seq,
5✔
474
            l,
5✔
475
            symbolic::Lt(l, this->output_channels_),
5✔
476
            symbolic::zero(),
5✔
477
            symbolic::add(l, symbolic::one()),
5✔
478
            ScheduleType_Sequential::create(),
5✔
479
            {},
5✔
480
            block->debug_info()
5✔
481
        );
5✔
482
        current_seq = &loop_l.root();
5✔
483

484
        // Add loops over output dimensions (again)
485
        for (size_t i = 0; i < dims; i++) {
15✔
486
            auto o_container = builder.find_new_name("_o");
10✔
487
            builder.add_container(o_container, indvar_type);
10✔
488
            auto o = symbolic::symbol(o_container);
10✔
489
            auto& loop_o = builder.add_map(
10✔
490
                *current_seq,
10✔
491
                o,
10✔
492
                symbolic::Lt(o, out_shape[i]),
10✔
493
                symbolic::zero(),
10✔
494
                symbolic::add(o, symbolic::one()),
10✔
495
                ScheduleType_Sequential::create(),
10✔
496
                {},
10✔
497
                block->debug_info()
10✔
498
            );
10✔
499
            current_seq = &loop_o.root();
10✔
500
            os[i] = o;
10✔
501
        }
10✔
502

503
        // Add transposed copy from temporary GEMM output to Y + add bias if available
504
        data_flow::Subset tmp_Y_subset;
5✔
505
        tmp_Y_subset.push_back(l);
5✔
506
        tmp_Y_subset.push_back(n);
5✔
507
        tmp_Y_subset.insert(tmp_Y_subset.end(), os.begin(), os.end());
5✔
508
        symbolic::MultiExpression tmp_Y_shape;
5✔
509
        tmp_Y_shape.push_back(this->output_channels_);
5✔
510
        tmp_Y_shape.push_back(this->shape_[0]);
5✔
511
        tmp_Y_shape.insert(tmp_Y_shape.end(), out_shape.begin(), out_shape.end());
5✔
512
        types::Tensor tmp_Y_tensor_type(tmp_Y_base_type, tmp_Y_shape);
5✔
513
        data_flow::Subset Y_subset;
5✔
514
        Y_subset.push_back(n);
5✔
515
        Y_subset.push_back(l);
5✔
516
        Y_subset.insert(Y_subset.end(), os.begin(), os.end());
5✔
517
        auto& transpose_block = builder.add_block(*current_seq, {}, block->debug_info());
5✔
518
        if (has_bias) {
5✔
NEW
519
            auto& tmp_Y_access = builder.add_access(transpose_block, tmp_Y_container, this->debug_info());
×
NEW
520
            auto& B_access = builder.add_access(transpose_block, access_B->data(), access_B->debug_info());
×
NEW
521
            auto& Y_access = builder.add_access(transpose_block, access_Y->data(), access_Y->debug_info());
×
NEW
522
            auto& tasklet = builder.add_tasklet(
×
NEW
523
                transpose_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
NEW
524
            );
×
NEW
525
            builder.add_computational_memlet(
×
NEW
526
                transpose_block, tmp_Y_access, tasklet, "_in1", tmp_Y_subset, tmp_Y_tensor_type, this->debug_info()
×
NEW
527
            );
×
NEW
528
            builder.add_computational_memlet(
×
NEW
529
                transpose_block, B_access, tasklet, "_in2", {l}, iedge_B->base_type(), iedge_B->debug_info()
×
NEW
530
            );
×
NEW
531
            builder.add_computational_memlet(
×
NEW
532
                transpose_block, tasklet, "_out", Y_access, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
NEW
533
            );
×
534
        } else {
5✔
535
            auto& tmp_Y_access = builder.add_access(transpose_block, tmp_Y_container, this->debug_info());
5✔
536
            auto& Y_access = builder.add_access(transpose_block, access_Y->data(), access_Y->debug_info());
5✔
537
            auto& tasklet =
5✔
538
                builder
5✔
539
                    .add_tasklet(transpose_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
540
            builder.add_computational_memlet(
5✔
541
                transpose_block, tmp_Y_access, tasklet, "_in", tmp_Y_subset, tmp_Y_tensor_type, this->debug_info()
5✔
542
            );
5✔
543
            builder.add_computational_memlet(
5✔
544
                transpose_block, tasklet, "_out", Y_access, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
5✔
545
            );
5✔
546
        }
5✔
547

548
        // Add free for patches container
549
        auto& patches_free_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
550
        {
5✔
551
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
552
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
553
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
5✔
554
            builder.add_computational_memlet(
5✔
555
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
5✔
556
            );
5✔
557
            builder.add_computational_memlet(
5✔
558
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
5✔
559
            );
5✔
560
        }
5✔
561

562
        // Add free for temporary GEMM output
563
        auto& tmp_Y_free_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
564
        {
5✔
565
            auto& tmp_Y_access_in = builder.add_access(tmp_Y_free_block, tmp_Y_container, this->debug_info());
5✔
566
            auto& tmp_Y_access_out = builder.add_access(tmp_Y_free_block, tmp_Y_container, this->debug_info());
5✔
567
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(tmp_Y_free_block, this->debug_info());
5✔
568
            builder.add_computational_memlet(
5✔
569
                tmp_Y_free_block, tmp_Y_access_in, libnode, "_ptr", {}, tmp_Y_type, this->debug_info()
5✔
570
            );
5✔
571
            builder.add_computational_memlet(
5✔
572
                tmp_Y_free_block, libnode, "_ptr", tmp_Y_access_out, {}, tmp_Y_type, this->debug_info()
5✔
573
            );
5✔
574
        }
5✔
575

576
        /* ===== No groups ====================================================================== */
577

578
    } else {
5✔
579
        /* ===== Groups ========================================================================= */
580

NEW
581
        auto in_channels = symbolic::div(this->shape_[1], this->group_);
×
NEW
582
        auto out_channels = symbolic::div(this->output_channels_, this->group_);
×
583

584
        // Add loop over batch size
NEW
585
        auto n_container = builder.find_new_name("_n");
×
NEW
586
        builder.add_container(n_container, indvar_type);
×
NEW
587
        auto n = symbolic::symbol(n_container);
×
NEW
588
        auto& loop_n = builder.add_map(
×
NEW
589
            new_sequence,
×
NEW
590
            n,
×
NEW
591
            symbolic::Lt(n, this->shape_[0]),
×
UNCOV
592
            symbolic::zero(),
×
NEW
593
            symbolic::add(n, symbolic::one()),
×
NEW
594
            ScheduleType_Sequential::create(),
×
UNCOV
595
            {},
×
NEW
596
            block->debug_info()
×
UNCOV
597
        );
×
598

599
        // Add loop over groups
NEW
600
        auto g_container = builder.find_new_name("_g");
×
NEW
601
        builder.add_container(g_container, indvar_type);
×
NEW
602
        auto g = symbolic::symbol(g_container);
×
NEW
603
        auto& loop_g = builder.add_map(
×
NEW
604
            loop_n.root(),
×
NEW
605
            g,
×
NEW
606
            symbolic::Lt(g, this->group_),
×
UNCOV
607
            symbolic::zero(),
×
NEW
608
            symbolic::add(g, symbolic::one()),
×
NEW
609
            ScheduleType_Sequential::create(),
×
UNCOV
610
            {},
×
NEW
611
            block->debug_info()
×
UNCOV
612
        );
×
613

614
        // Add patches container with malloc
NEW
615
        symbolic::Expression patches_size = in_channels;
×
NEW
616
        for (size_t i = 0; i < dims; i++) {
×
NEW
617
            patches_size = symbolic::mul(patches_size, symbolic::mul(this->kernel_shape_[i], out_shape[i]));
×
NEW
618
        }
×
NEW
619
        types::Pointer patches_type(base_type);
×
NEW
620
        auto patches_container = builder.find_new_name("_patches");
×
NEW
621
        builder.add_container(patches_container, patches_type);
×
NEW
622
        auto& patches_malloc_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
NEW
623
        {
×
NEW
624
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
×
NEW
625
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
×
NEW
626
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
×
NEW
627
            );
×
NEW
628
            builder.add_computational_memlet(
×
NEW
629
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
×
NEW
630
            );
×
NEW
631
        }
×
632

633
        // Memset patches to zero
NEW
634
        auto& patches_memset_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
NEW
635
        {
×
NEW
636
            auto& patches_access = builder.add_access(patches_memset_block, patches_container, this->debug_info());
×
NEW
637
            auto& libnode = builder.add_library_node<stdlib::MemsetNode>(
×
NEW
638
                patches_memset_block,
×
NEW
639
                this->debug_info(),
×
NEW
640
                symbolic::zero(),
×
NEW
641
                symbolic::mul(patches_size, symbolic::size_of_type(base_type))
×
NEW
642
            );
×
NEW
643
            builder.add_computational_memlet(
×
NEW
644
                patches_memset_block, libnode, "_ptr", patches_access, {}, patches_type, this->debug_info()
×
NEW
645
            );
×
NEW
646
        }
×
647

648
        // Add loops over output dimensions
NEW
649
        structured_control_flow::Sequence* current_seq = &loop_g.root();
×
NEW
650
        symbolic::SymbolVec os;
×
NEW
651
        os.reserve(dims);
×
NEW
652
        for (size_t i = 0; i < dims; i++) {
×
NEW
653
            auto o_container = builder.find_new_name("_o");
×
NEW
654
            builder.add_container(o_container, indvar_type);
×
NEW
655
            auto o = symbolic::symbol(o_container);
×
NEW
656
            os.push_back(o);
×
NEW
657
            auto& loop_o = builder.add_map(
×
NEW
658
                *current_seq,
×
NEW
659
                o,
×
NEW
660
                symbolic::Lt(o, out_shape[i]),
×
NEW
661
                symbolic::zero(),
×
NEW
662
                symbolic::add(o, symbolic::one()),
×
NEW
663
                ScheduleType_Sequential::create(),
×
NEW
664
                {},
×
NEW
665
                block->debug_info()
×
NEW
666
            );
×
NEW
667
            current_seq = &loop_o.root();
×
NEW
668
        }
×
669

670
        // Add loop over channels
NEW
671
        auto c_container = builder.find_new_name("_c");
×
NEW
672
        builder.add_container(c_container, indvar_type);
×
NEW
673
        auto c = symbolic::symbol(c_container);
×
NEW
674
        auto& loop_c = builder.add_map(
×
NEW
675
            *current_seq,
×
NEW
676
            c,
×
NEW
677
            symbolic::Lt(c, in_channels),
×
NEW
678
            symbolic::zero(),
×
NEW
679
            symbolic::add(c, symbolic::one()),
×
NEW
680
            ScheduleType_Sequential::create(),
×
NEW
681
            {},
×
NEW
682
            block->debug_info()
×
UNCOV
683
        );
×
NEW
684
        current_seq = &loop_c.root();
×
685

686
        // Add loops over kernel shape
NEW
687
        symbolic::SymbolVec ks;
×
NEW
688
        ks.reserve(dims);
×
NEW
689
        symbolic::MultiExpression is;
×
NEW
690
        is.reserve(dims);
×
NEW
691
        for (size_t i = 0; i < dims; i++) {
×
NEW
692
            auto k_container = builder.find_new_name("_k");
×
NEW
693
            builder.add_container(k_container, indvar_type);
×
NEW
694
            auto k = symbolic::symbol(k_container);
×
NEW
695
            ks.push_back(k);
×
NEW
696
            auto i_expr = symbolic::
×
NEW
697
                add(symbolic::sub(symbolic::mul(os[i], this->strides_[i]), this->pads_[i]),
×
NEW
698
                    symbolic::mul(k, this->dilations_[i]));
×
NEW
699
            is.push_back(i_expr);
×
NEW
700
            auto& loop_k = builder.add_map(
×
NEW
701
                *current_seq,
×
NEW
702
                k,
×
NEW
703
                symbolic::And(symbolic::Lt(k, this->kernel_shape_[i]), symbolic::Lt(i_expr, this->shape_[i + 2])),
×
NEW
704
                symbolic::zero(),
×
NEW
705
                symbolic::add(k, symbolic::one()),
×
NEW
706
                ScheduleType_Sequential::create(),
×
NEW
707
                {},
×
NEW
708
                block->debug_info()
×
NEW
709
            );
×
NEW
710
            current_seq = &loop_k.root();
×
NEW
711
        }
×
712

713
        // Determine patches subset & tensor type
NEW
714
        data_flow::Subset patches_subset;
×
NEW
715
        patches_subset.push_back(c);
×
NEW
716
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
×
NEW
717
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
×
NEW
718
        symbolic::MultiExpression patches_shape;
×
NEW
719
        patches_shape.push_back(in_channels);
×
NEW
720
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
×
NEW
721
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
×
NEW
722
        types::Tensor patches_tensor_type(base_type, patches_shape);
×
723

724
        // Determine subset for X
NEW
725
        data_flow::Subset subset_X;
×
NEW
726
        subset_X.push_back(n);
×
NEW
727
        subset_X.push_back(symbolic::add(symbolic::mul(in_channels, g), c));
×
NEW
728
        subset_X.insert(subset_X.end(), is.begin(), is.end());
×
729

730
        // Add copy from X to patches
NEW
731
        auto& true_block = builder.add_block(*current_seq, {}, block->debug_info());
×
NEW
732
        {
×
NEW
733
            auto& X_access = builder.add_access(true_block, access_X->data(), access_X->debug_info());
×
NEW
734
            auto& patches_access = builder.add_access(true_block, patches_container, this->debug_info());
×
NEW
735
            auto& tasklet =
×
NEW
736
                builder.add_tasklet(true_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
×
NEW
737
            builder.add_computational_memlet(
×
NEW
738
                true_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
×
NEW
739
            );
×
NEW
740
            builder.add_computational_memlet(
×
NEW
741
                true_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
×
NEW
742
            );
×
NEW
743
        }
×
744

745
        // Add reference to W
NEW
746
        auto ref_W_container = builder.find_new_name("_ref_W");
×
NEW
747
        types::Scalar ref_W_base_type(builder.subject().type(access_W->data()).primitive_type());
×
NEW
748
        types::Pointer ref_W_type(ref_W_base_type);
×
NEW
749
        builder.add_container(ref_W_container, ref_W_type);
×
NEW
750
        auto ref_W_subset = symbolic::mul(symbolic::mul(out_channels, g), in_channels);
×
NEW
751
        for (size_t i = 0; i < dims; i++) {
×
NEW
752
            ref_W_subset = symbolic::mul(ref_W_subset, this->kernel_shape_[i]);
×
NEW
753
        }
×
NEW
754
        auto& ref_W_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
NEW
755
        {
×
NEW
756
            auto& W_access = builder.add_access(ref_W_block, access_W->data(), access_W->debug_info());
×
NEW
757
            auto& ref_W_access = builder.add_access(ref_W_block, ref_W_container, access_W->debug_info());
×
NEW
758
            builder.add_reference_memlet(ref_W_block, W_access, ref_W_access, {ref_W_subset}, ref_W_type);
×
NEW
759
        }
×
760

761
        // Add reference to Y
NEW
762
        auto ref_Y_container = builder.find_new_name("_ref_Y");
×
NEW
763
        types::Scalar ref_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
×
NEW
764
        types::Pointer ref_Y_type(ref_Y_base_type);
×
NEW
765
        builder.add_container(ref_Y_container, ref_Y_type);
×
NEW
766
        auto ref_Y_subset = symbolic::add(symbolic::mul(this->output_channels_, n), symbolic::mul(out_channels, g));
×
NEW
767
        for (size_t i = 0; i < dims; i++) {
×
NEW
768
            ref_Y_subset = symbolic::mul(ref_Y_subset, out_shape[i]);
×
NEW
769
        }
×
NEW
770
        auto& ref_Y_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
NEW
771
        {
×
NEW
772
            auto& Y_access = builder.add_access(ref_Y_block, access_Y->data(), access_Y->debug_info());
×
NEW
773
            auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, access_Y->debug_info());
×
NEW
774
            builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_subset}, ref_Y_type);
×
NEW
775
        }
×
776

777
        // Add GEMM node
NEW
778
        auto& gemm_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
NEW
779
        {
×
NEW
780
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
×
NEW
781
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
×
NEW
782
            auto& ref_W_access = builder.add_access(gemm_block, ref_W_container, access_W->debug_info());
×
NEW
783
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
×
NEW
784
            auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
NEW
785
            auto& ref_Y_access_out = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
NEW
786
            symbolic::Expression gemm_m = out_channels;
×
NEW
787
            symbolic::Expression gemm_n = symbolic::one();
×
NEW
788
            symbolic::Expression gemm_k = in_channels;
×
NEW
789
            for (size_t i = 0; i < dims; i++) {
×
NEW
790
                gemm_n = symbolic::mul(gemm_n, out_shape[i]);
×
NEW
791
                gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
×
NEW
792
            }
×
NEW
793
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
×
NEW
794
                gemm_block,
×
NEW
795
                this->debug_info(),
×
NEW
796
                blas::ImplementationType_BLAS,
×
NEW
797
                precision, // precision
×
NEW
798
                blas::BLAS_Layout::RowMajor, // layout
×
NEW
799
                blas::BLAS_Transpose::No, // transA
×
NEW
800
                blas::BLAS_Transpose::No, // transB
×
NEW
801
                gemm_m, // m
×
NEW
802
                gemm_n, // n
×
NEW
803
                gemm_k, // k
×
NEW
804
                gemm_k, // lda
×
NEW
805
                gemm_n, // ldb
×
NEW
806
                gemm_n // ldc
×
NEW
807
            );
×
NEW
808
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
×
NEW
809
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
×
NEW
810
            builder
×
NEW
811
                .add_computational_memlet(gemm_block, ref_W_access, libnode, "__A", {}, ref_W_type, iedge_W->debug_info());
×
NEW
812
            builder.add_computational_memlet(
×
NEW
813
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
×
NEW
814
            );
×
NEW
815
            builder.add_computational_memlet(
×
NEW
816
                gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, oedge_Y->debug_info()
×
NEW
817
            );
×
NEW
818
            builder.add_computational_memlet(
×
NEW
819
                gemm_block, libnode, "__C", ref_Y_access_out, {}, ref_Y_type, oedge_Y->debug_info()
×
NEW
820
            );
×
NEW
821
        }
×
822

823
        // Add bias if available
NEW
824
        if (has_bias) {
×
825
            // Add loop over output channels
NEW
826
            auto l_container = builder.find_new_name("_l");
×
NEW
827
            builder.add_container(l_container, indvar_type);
×
NEW
828
            auto l = symbolic::symbol(l_container);
×
NEW
829
            auto& loop_l = builder.add_map(
×
NEW
830
                loop_g.root(),
×
NEW
831
                l,
×
NEW
832
                symbolic::Lt(l, out_channels),
×
NEW
833
                symbolic::zero(),
×
NEW
834
                symbolic::add(l, symbolic::one()),
×
NEW
835
                ScheduleType_Sequential::create(),
×
NEW
836
                {},
×
NEW
837
                block->debug_info()
×
NEW
838
            );
×
NEW
839
            current_seq = &loop_l.root();
×
840

841
            // Add loops over output dimensions (again)
NEW
842
            for (size_t i = 0; i < dims; i++) {
×
NEW
843
                auto o_container = builder.find_new_name("_o");
×
NEW
844
                builder.add_container(o_container, indvar_type);
×
NEW
845
                auto o = symbolic::symbol(o_container);
×
NEW
846
                auto& loop_o = builder.add_map(
×
NEW
847
                    *current_seq,
×
NEW
848
                    o,
×
NEW
849
                    symbolic::Lt(o, out_shape[i]),
×
NEW
850
                    symbolic::zero(),
×
NEW
851
                    symbolic::add(o, symbolic::one()),
×
NEW
852
                    ScheduleType_Sequential::create(),
×
NEW
853
                    {},
×
NEW
854
                    block->debug_info()
×
NEW
855
                );
×
NEW
856
                current_seq = &loop_o.root();
×
NEW
857
                os[i] = o;
×
NEW
858
            }
×
859

860
            // Add bias to Y
NEW
861
            data_flow::Subset Y_subset;
×
NEW
862
            Y_subset.push_back(n);
×
NEW
863
            Y_subset.push_back(symbolic::add(symbolic::mul(out_channels, g), l));
×
NEW
864
            Y_subset.insert(Y_subset.end(), os.begin(), os.end());
×
NEW
865
            auto B_subset = symbolic::add(symbolic::mul(out_channels, g), l);
×
NEW
866
            auto& bias_block = builder.add_block(*current_seq, {}, block->debug_info());
×
NEW
867
            {
×
NEW
868
                auto& B_access = builder.add_access(bias_block, access_B->data(), access_B->debug_info());
×
NEW
869
                auto& Y_access_in = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
NEW
870
                auto& Y_access_out = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
NEW
871
                auto& tasklet = builder.add_tasklet(
×
NEW
872
                    bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
NEW
873
                );
×
NEW
874
                builder.add_computational_memlet(
×
NEW
875
                    bias_block, Y_access_in, tasklet, "_in1", Y_subset, oedge_Y->base_type(), this->debug_info()
×
NEW
876
                );
×
NEW
877
                builder.add_computational_memlet(
×
NEW
878
                    bias_block, B_access, tasklet, "_in2", {B_subset}, iedge_B->base_type(), iedge_B->debug_info()
×
NEW
879
                );
×
NEW
880
                builder.add_computational_memlet(
×
NEW
881
                    bias_block, tasklet, "_out", Y_access_out, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
NEW
882
                );
×
NEW
883
            }
×
NEW
884
        }
×
885

886
        // Add free for patches container
NEW
887
        auto& patches_free_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
NEW
888
        {
×
NEW
889
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
NEW
890
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
NEW
891
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
×
NEW
892
            builder.add_computational_memlet(
×
NEW
893
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
×
NEW
894
            );
×
NEW
895
            builder.add_computational_memlet(
×
NEW
896
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
×
NEW
897
            );
×
NEW
898
        }
×
899

900
        /* ===== Groups ========================================================================= */
UNCOV
901
    }
×
902

903
    // Clean up the original block
904
    builder.remove_memlet(*block, *iedge_X);
5✔
905
    builder.remove_memlet(*block, *iedge_W);
5✔
906
    if (has_bias) {
5✔
NEW
907
        builder.remove_memlet(*block, *iedge_B);
×
NEW
908
    }
×
909
    builder.remove_memlet(*block, *oedge_Y);
5✔
910
    builder.remove_node(*block, *access_X);
5✔
911
    builder.remove_node(*block, *access_W);
5✔
912
    if (has_bias) {
5✔
NEW
913
        builder.remove_node(*block, *access_B);
×
NEW
914
    }
×
915
    builder.remove_node(*block, *access_Y);
5✔
916
    builder.remove_node(*block, *this);
5✔
917
    builder.remove_child(*block_parent, block_index + 1);
5✔
918

919
    return true;
5✔
920
}
5✔
921

922
symbolic::SymbolSet ConvNode::symbols() const {
    // Union of all free symbols occurring in any symbolic attribute of this
    // convolution node (shapes, strides, pads, dilations, channels, groups).
    symbolic::SymbolSet syms;

    // Folds the atoms of a single expression into the result set.
    auto collect = [&syms](const symbolic::Expression& expr) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    };

    // All vector-valued attributes share the same element type, so one pass
    // over the member list replaces five copy-pasted loops.
    for (const auto* exprs : {&shape_, &kernel_shape_, &strides_, &pads_, &dilations_}) {
        for (const auto& expr : *exprs) {
            collect(expr);
        }
    }

    // Scalar-valued attributes.
    collect(output_channels_);
    collect(group_);

    return syms;
}
959

960
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Applies the substitution old_expression -> new_expression to every
    // symbolic attribute of this node, in place.
    auto substitute = [&old_expression, &new_expression](symbolic::Expression& expr) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    };

    // Vector-valued attributes.
    for (auto* exprs : {&shape_, &kernel_shape_, &strides_, &pads_, &dilations_}) {
        for (auto& expr : *exprs) {
            substitute(expr);
        }
    }

    // Scalar-valued attributes.
    substitute(output_channels_);
    substitute(group_);
}
979

980
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
981
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
1✔
982
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
1✔
983
        element_id,
1✔
984
        this->debug_info(),
1✔
985
        vertex,
1✔
986
        parent,
1✔
987
        shape_,
1✔
988
        kernel_shape_,
1✔
989
        strides_,
1✔
990
        pads_,
1✔
991
        dilations_,
1✔
992
        output_channels_,
1✔
993
        group_
1✔
994
    ));
1✔
995
}
1✔
996

997
std::string ConvNode::toStr() const {
×
NEW
998
    std::stringstream result;
×
NEW
999
    result << "Conv(shape=[";
×
1000
    for (size_t i = 0; i < shape_.size(); ++i) {
×
NEW
1001
        if (i > 0) {
×
NEW
1002
            result << ", ";
×
NEW
1003
        }
×
NEW
1004
        result << shape_[i]->__str__();
×
1005
    }
×
NEW
1006
    result << "], kernel_shape=[";
×
1007
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
NEW
1008
        if (i > 0) {
×
NEW
1009
            result << ", ";
×
NEW
1010
        }
×
NEW
1011
        result << kernel_shape_[i]->__str__();
×
1012
    }
×
NEW
1013
    result << "], strides=[";
×
1014
    for (size_t i = 0; i < strides_.size(); ++i) {
×
NEW
1015
        if (i > 0) {
×
NEW
1016
            result << ", ";
×
NEW
1017
        }
×
NEW
1018
        result << strides_[i]->__str__();
×
NEW
1019
    }
×
NEW
1020
    result << "], pads=[";
×
NEW
1021
    for (size_t i = 0; i < pads_.size(); ++i) {
×
NEW
1022
        if (i > 0) {
×
NEW
1023
            result << ", ";
×
NEW
1024
        }
×
NEW
1025
        result << pads_[i]->__str__();
×
NEW
1026
    }
×
NEW
1027
    result << "], dilations=[";
×
NEW
1028
    for (size_t i = 0; i < dilations_.size(); ++i) {
×
NEW
1029
        if (i > 0) {
×
NEW
1030
            result << ", ";
×
NEW
1031
        }
×
NEW
1032
        result << dilations_[i]->__str__();
×
1033
    }
×
NEW
1034
    result << "], output_channels=" + output_channels_->__str__();
×
NEW
1035
    result << ", group=" + group_->__str__() + ")";
×
NEW
1036
    return result.str();
×
UNCOV
1037
}
×
1038

1039
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    // Serializes a ConvNode's symbolic attributes to JSON; each expression is
    // rendered through the project's JSONSerializer as a string.
    const ConvNode& conv_node = static_cast<const ConvNode&>(library_node);
    nlohmann::json j;

    j["code"] = conv_node.code().value();

    serializer::JSONSerializer serializer;

    // Converts a vector of expressions into a JSON array of strings.
    auto to_array = [&serializer](const auto& exprs) {
        nlohmann::json arr = nlohmann::json::array();
        for (auto& expr : exprs) {
            arr.push_back(serializer.expression(expr));
        }
        return arr;
    };

    j["shape"] = to_array(conv_node.shape());
    j["kernel_shape"] = to_array(conv_node.kernel_shape());
    j["strides"] = to_array(conv_node.strides());
    j["pads"] = to_array(conv_node.pads());
    j["dilations"] = to_array(conv_node.dilations());

    j["output_channels"] = serializer.expression(conv_node.output_channels());
    j["group"] = serializer.expression(conv_node.group());

    return j;
}
1077

1078
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Reconstructs a ConvNode from its JSON form and attaches it to `parent`.
    // Only element_id, code, debug_info, and kernel_shape are mandatory; the
    // remaining attributes fall back to empty vectors / 1 when absent.
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parses an optional array of expression strings; an absent key yields an
    // empty vector.
    auto parse_list = [&j](const char* key) -> std::vector<symbolic::Expression> {
        std::vector<symbolic::Expression> exprs;
        if (j.contains(key)) {
            for (const auto& entry : j[key]) {
                exprs.push_back(symbolic::parse(entry.get<std::string>()));
            }
        }
        return exprs;
    };

    // Parses an optional scalar expression; an absent key defaults to 1.
    auto parse_scalar = [&j](const char* key) -> symbolic::Expression {
        if (j.contains(key)) {
            return symbolic::parse(j[key].get<std::string>());
        }
        return symbolic::one();
    };

    auto shape = parse_list("shape");
    // kernel_shape is guaranteed present by the assert above.
    auto kernel_shape = parse_list("kernel_shape");
    auto strides = parse_list("strides");
    auto pads = parse_list("pads");
    auto dilations = parse_list("dilations");
    auto output_channels = parse_scalar("output_channels");
    auto group = parse_scalar("group");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
1135

1136
} // namespace tensor
1137
} // namespace math
1138
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc