• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 24215882789

09 Apr 2026 10:12PM UTC coverage: 64.375% (-0.007%) from 64.382%
24215882789

Pull #668

github

web-flow
Merge 6f7f28e8f into bb3981349
Pull Request #668: Offload Memset to GPU

249 of 381 new or added lines in 18 files covered. (65.35%)

189 existing lines in 2 files now uncovered.

29942 of 46512 relevant lines covered (64.37%)

584.42 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

45.18
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4
#include <sstream>
5

6
#include "sdfg/analysis/analysis.h"
7
#include "sdfg/builder/structured_sdfg_builder.h"
8
#include "sdfg/data_flow/access_node.h"
9
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
10
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
11
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
12
#include "sdfg/data_flow/library_nodes/stdlib/memset.h"
13
#include "sdfg/data_flow/memlet.h"
14
#include "sdfg/data_flow/tasklet.h"
15
#include "sdfg/exceptions.h"
16
#include "sdfg/structured_control_flow/block.h"
17
#include "sdfg/structured_control_flow/map.h"
18
#include "sdfg/structured_control_flow/sequence.h"
19
#include "sdfg/symbolic/symbolic.h"
20
#include "sdfg/types/pointer.h"
21
#include "sdfg/types/scalar.h"
22
#include "sdfg/types/tensor.h"
23
#include "sdfg/types/type.h"
24

25
#include "sdfg/analysis/scope_analysis.h"
26
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
27
#include "symengine/integer.h"
28
#include "symengine/symengine_rcp.h"
29

30
namespace sdfg {
31
namespace math {
32
namespace tensor {
33

34
ConvNode::ConvNode(
35
    size_t element_id,
36
    const DebugInfo& debug_info,
37
    const graph::Vertex vertex,
38
    data_flow::DataFlowGraph& parent,
39
    const std::vector<symbolic::Expression>& shape,
40
    const std::vector<symbolic::Expression>& kernel_shape,
41
    const std::vector<symbolic::Expression>& strides,
42
    const std::vector<symbolic::Expression>& pads,
43
    const std::vector<symbolic::Expression>& dilations,
44
    symbolic::Expression output_channels,
45
    symbolic::Expression group
46
)
47
    : TensorNode(
22✔
48
          element_id,
22✔
49
          debug_info,
22✔
50
          vertex,
22✔
51
          parent,
22✔
52
          LibraryNodeType_Conv,
22✔
53
          {"Y"},
22✔
54
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
22✔
55
          data_flow::ImplementationType_NONE
22✔
56
      ),
22✔
57
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
22✔
58
      output_channels_(output_channels), group_(group) {}
22✔
59

60
void ConvNode::validate(const Function& function) const {
17✔
61
    TensorNode::validate(function);
17✔
62

63
    auto& graph = this->get_parent();
17✔
64

65
    // Custom validation for ConvNode that handles optional bias input
66
    // We expect X, W as required inputs and optionally B (bias)
67

68
    // Collect all input edges by connector name
69
    std::map<std::string, const data_flow::Memlet*> input_edges;
17✔
70
    for (auto& iedge : graph.in_edges(*this)) {
34✔
71
        input_edges[iedge.dst_conn()] = &iedge;
34✔
72
    }
34✔
73

74
    // Check that required inputs X and W are present
75
    if (input_edges.find("X") == input_edges.end()) {
17✔
76
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
77
    }
×
78
    if (input_edges.find("W") == input_edges.end()) {
17✔
79
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
80
    }
×
81

82
    // Validate that parameters are not empty
83
    if (shape_.empty()) {
17✔
84
        throw InvalidSDFGException("ConvNode shape cannot be empty");
×
85
    }
×
86
    if (kernel_shape_.empty()) {
17✔
87
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
88
    }
×
89
    if (strides_.empty()) {
17✔
90
        throw InvalidSDFGException("ConvNode strides cannot be empty");
×
91
    }
×
92
    if (pads_.empty()) {
17✔
93
        throw InvalidSDFGException("ConvNode pads cannot be empty");
×
94
    }
×
95
    if (dilations_.empty()) {
17✔
96
        throw InvalidSDFGException("ConvNode dilations cannot be empty");
×
97
    }
×
98

99
    // Validate consistent dimensions
100
    size_t spatial_dims = kernel_shape_.size();
17✔
101

102
    if (shape_.size() != spatial_dims + 2) {
17✔
103
        throw InvalidSDFGException("ConvNode shape must match kernel spatial dimensions + 2");
×
104
    }
×
105

106
    if (strides_.size() != spatial_dims) {
17✔
107
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
108
    }
1✔
109

110
    if (pads_.size() != 2 * spatial_dims) {
16✔
111
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
112
    }
1✔
113

114
    if (dilations_.size() != spatial_dims) {
15✔
115
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
116
    }
×
117

118
    // Validate groups
119
    if (SymEngine::is_a<SymEngine::Integer>(*this->group_)) {
15✔
120
        auto group_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->group_)->as_int();
15✔
121
        if (SymEngine::is_a<SymEngine::Integer>(*this->shape_[1])) {
15✔
122
            auto input_channels_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->shape_[1])->as_int();
15✔
123
            if (input_channels_int % group_int != 0) {
15✔
124
                throw InvalidSDFGException("ConvNode input channels must be divisible by groups");
×
125
            }
×
126
        }
15✔
127
        if (SymEngine::is_a<SymEngine::Integer>(*this->output_channels_)) {
15✔
128
            auto output_channels_int =
15✔
129
                SymEngine::rcp_static_cast<const SymEngine::Integer>(this->output_channels_)->as_int();
15✔
130
            if (output_channels_int % group_int != 0) {
15✔
131
                throw InvalidSDFGException("ConvNode output channels must be divisible by groups");
×
132
            }
×
133
        }
15✔
134
    }
15✔
135
}
15✔
136

137
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
138
    // Validate nodes are standalone in the data flow graph
139
    auto& dfg = this->get_parent();
5✔
140
    if ((dfg.nodes().size() != 4 || dfg.edges().size() != 3) && (dfg.nodes().size() != 5 || dfg.edges().size() != 4)) {
5✔
141
        return false;
×
142
    }
×
143

144
    // Get edges
145
    auto iedges = dfg.in_edges_by_connector(*this);
5✔
146
    auto oedges = dfg.out_edges_by_connector(*this);
5✔
147
    if (iedges.size() != 3 || oedges.size() != 1) {
5✔
148
        return false;
×
149
    }
×
150
    auto* iedge_X = iedges.at(0);
5✔
151
    auto* iedge_W = iedges.at(1);
5✔
152
    auto* iedge_B = iedges.at(2);
5✔
153
    auto* oedge_Y = oedges.at(0);
5✔
154
    if (!iedge_X || !iedge_W || !oedge_Y) {
5✔
155
        return false;
×
156
    }
×
157
    bool has_bias = iedge_B != nullptr;
5✔
158

159
    // Get access nodes
160
    auto* access_X = dynamic_cast<data_flow::AccessNode*>(&iedge_X->src());
5✔
161
    auto* access_W = dynamic_cast<data_flow::AccessNode*>(&iedge_W->src());
5✔
162
    auto* access_B = (has_bias ? dynamic_cast<data_flow::AccessNode*>(&iedge_B->src()) : nullptr);
5✔
163
    auto* access_Y = dynamic_cast<data_flow::AccessNode*>(&oedge_Y->dst());
5✔
164
    if (!access_X || !access_W || (has_bias && !access_B) || !access_Y) {
5✔
165
        return false;
×
166
    }
×
167

168
    // Get block & its parent
169
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
5✔
170
    if (!block) {
5✔
171
        return false;
×
172
    }
×
173
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
174
    auto* block_parent = dynamic_cast<structured_control_flow::Sequence*>(scope_analysis.parent_scope(block));
5✔
175
    if (!block_parent) {
5✔
176
        return false;
×
177
    }
×
178
    size_t block_index = block_parent->index(*block);
5✔
179
    if (block_index >= block_parent->size()) {
5✔
180
        return false;
×
181
    }
×
182

183
    // Determine BLAS precision
184
    blas::BLAS_Precision precision;
5✔
185
    types::Scalar base_type(this->primitive_type(dfg));
5✔
186
    switch (base_type.primitive_type()) {
5✔
187
        case types::PrimitiveType::Half:
×
188
            precision = blas::BLAS_Precision::h;
×
189
            break;
×
190
        case types::PrimitiveType::Float:
5✔
191
            precision = blas::BLAS_Precision::s;
5✔
192
            break;
5✔
193
        case types::PrimitiveType::Double:
×
194
            precision = blas::BLAS_Precision::d;
×
195
            break;
×
196
        default:
×
197
            return false;
×
198
    }
5✔
199

200
    // Create new sequence for expansion
201
    auto& new_sequence = builder.add_sequence_before(
5✔
202
        *block_parent, *block, block_parent->at(block_index).second.assignments(), block->debug_info()
5✔
203
    );
5✔
204

205
    // Dimensions, i.e., 1D, 2D, 3D, ...
206
    size_t dims = this->kernel_shape_.size();
5✔
207
    symbolic::MultiExpression out_shape;
5✔
208
    out_shape.reserve(dims);
5✔
209
    // out_shape[i] = (shape[i + 2] + pads[i] + pads[dims + i] - dilations[i] * (kernel_shape[i] - 1) - 1)
210
    //                 / strides[i] + 1
211
    for (size_t i = 0; i < dims; i++) {
15✔
212
        out_shape.push_back(symbolic::add(
10✔
213
            symbolic::div(
10✔
214
                symbolic::sub(
10✔
215
                    symbolic::
10✔
216
                        sub(symbolic::add(this->shape_[i + 2], symbolic::add(this->pads_[i], this->pads_[dims + i])),
10✔
217
                            symbolic::mul(this->dilations_[i], symbolic::sub(this->kernel_shape_[i], symbolic::one()))),
10✔
218
                    symbolic::one()
10✔
219
                ),
10✔
220
                this->strides_[i]
10✔
221
            ),
10✔
222
            symbolic::one()
10✔
223
        ));
10✔
224
    }
10✔
225
    types::Scalar indvar_type(types::PrimitiveType::Int64);
5✔
226

227
    // If there are no groups (i.e., group == 1), then we can do im2row with one GEMM.
228
    // Else, we do naïve im2col with multiple GEMM's.
229
    if (symbolic::eq(this->group_, symbolic::one())) {
5✔
230
        /* ===== No groups ====================================================================== */
231

232
        // Add patches container with malloc
233
        symbolic::Expression patches_size = symbolic::mul(this->shape_[0], this->shape_[1]);
5✔
234
        for (size_t i = 0; i < dims; i++) {
15✔
235
            patches_size = symbolic::mul(patches_size, symbolic::mul(this->kernel_shape_[i], out_shape[i]));
10✔
236
        }
10✔
237
        types::Pointer patches_type(base_type);
5✔
238
        auto patches_container = builder.find_new_name("_patches");
5✔
239
        builder.add_container(patches_container, patches_type);
5✔
240
        auto& patches_malloc_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
241
        {
5✔
242
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
5✔
243
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
5✔
244
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
5✔
245
            );
5✔
246
            builder.add_computational_memlet(
5✔
247
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
5✔
248
            );
5✔
249
        }
5✔
250

251
        // Add malloc for temporary GEMM output
252
        symbolic::Expression tmp_Y_size = symbolic::mul(this->output_channels_, this->shape_[0]);
5✔
253
        for (size_t i = 0; i < dims; i++) {
15✔
254
            tmp_Y_size = symbolic::mul(tmp_Y_size, out_shape[i]);
10✔
255
        }
10✔
256
        auto tmp_Y_container = builder.find_new_name("_tmp_Y");
5✔
257
        types::Scalar tmp_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
5✔
258
        types::Pointer tmp_Y_type(tmp_Y_base_type);
5✔
259
        builder.add_container(tmp_Y_container, tmp_Y_type);
5✔
260
        auto& tmp_Y_malloc_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
261
        {
5✔
262
            auto& tmp_Y_access = builder.add_access(tmp_Y_malloc_block, tmp_Y_container, this->debug_info());
5✔
263
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
5✔
264
                tmp_Y_malloc_block,
5✔
265
                this->debug_info(),
5✔
266
                symbolic::mul(tmp_Y_size, symbolic::size_of_type(tmp_Y_base_type))
5✔
267
            );
5✔
268
            builder.add_computational_memlet(
5✔
269
                tmp_Y_malloc_block, libnode, "_ret", tmp_Y_access, {}, tmp_Y_type, this->debug_info()
5✔
270
            );
5✔
271
        }
5✔
272

273
        // Add loop over batch size
274
        auto n_container = builder.find_new_name("_n");
5✔
275
        builder.add_container(n_container, indvar_type);
5✔
276
        auto n = symbolic::symbol(n_container);
5✔
277
        auto& loop_n = builder.add_map(
5✔
278
            new_sequence,
5✔
279
            n,
5✔
280
            symbolic::Lt(n, this->shape_[0]),
5✔
281
            symbolic::zero(),
5✔
282
            symbolic::add(n, symbolic::one()),
5✔
283
            ScheduleType_Sequential::create(),
5✔
284
            {},
5✔
285
            block->debug_info()
5✔
286
        );
5✔
287
        structured_control_flow::Sequence* current_seq = &loop_n.root();
5✔
288

289
        // Add loops over output dimensions
290
        symbolic::SymbolVec os;
5✔
291
        os.reserve(dims);
5✔
292
        for (size_t i = 0; i < dims; i++) {
15✔
293
            auto o_container = builder.find_new_name("_o");
10✔
294
            builder.add_container(o_container, indvar_type);
10✔
295
            auto o = symbolic::symbol(o_container);
10✔
296
            os.push_back(o);
10✔
297
            auto& loop_o = builder.add_map(
10✔
298
                *current_seq,
10✔
299
                o,
10✔
300
                symbolic::Lt(o, out_shape[i]),
10✔
301
                symbolic::zero(),
10✔
302
                symbolic::add(o, symbolic::one()),
10✔
303
                ScheduleType_Sequential::create(),
10✔
304
                {},
10✔
305
                block->debug_info()
10✔
306
            );
10✔
307
            current_seq = &loop_o.root();
10✔
308
        }
10✔
309

310
        // Add loop over channels
311
        auto c_container = builder.find_new_name("_c");
5✔
312
        builder.add_container(c_container, indvar_type);
5✔
313
        auto c = symbolic::symbol(c_container);
5✔
314
        auto& loop_c = builder.add_map(
5✔
315
            *current_seq,
5✔
316
            c,
5✔
317
            symbolic::Lt(c, this->shape_[1]),
5✔
318
            symbolic::zero(),
5✔
319
            symbolic::add(c, symbolic::one()),
5✔
320
            ScheduleType_Sequential::create(),
5✔
321
            {},
5✔
322
            block->debug_info()
5✔
323
        );
5✔
324
        current_seq = &loop_c.root();
5✔
325

326
        // Add loops over kernel shape
327
        symbolic::SymbolVec ks;
5✔
328
        ks.reserve(dims);
5✔
329
        for (size_t i = 0; i < dims; i++) {
15✔
330
            auto k_container = builder.find_new_name("_k");
10✔
331
            builder.add_container(k_container, indvar_type);
10✔
332
            auto k = symbolic::symbol(k_container);
10✔
333
            ks.push_back(k);
10✔
334
            auto& loop_k = builder.add_map(
10✔
335
                *current_seq,
10✔
336
                k,
10✔
337
                symbolic::Lt(k, this->kernel_shape_[i]),
10✔
338
                symbolic::zero(),
10✔
339
                symbolic::add(k, symbolic::one()),
10✔
340
                ScheduleType_Sequential::create(),
10✔
341
                {},
10✔
342
                block->debug_info()
10✔
343
            );
10✔
344
            current_seq = &loop_k.root();
10✔
345
        }
10✔
346

347
        // Add if/else to stay in bounds for copying
348
        symbolic::MultiExpression is;
5✔
349
        is.reserve(dims);
5✔
350
        symbolic::Condition copy_condition = symbolic::__true__();
5✔
351
        symbolic::Condition zero_condition = symbolic::__false__();
5✔
352
        for (size_t i = 0; i < dims; i++) {
15✔
353
            auto i_expr = symbolic::
10✔
354
                add(symbolic::sub(symbolic::mul(os[i], this->strides_[i]), this->pads_[i]),
10✔
355
                    symbolic::mul(ks[i], this->dilations_[i]));
10✔
356
            is.push_back(i_expr);
10✔
357
            copy_condition = symbolic::
10✔
358
                And(copy_condition,
10✔
359
                    symbolic::And(symbolic::Lt(i_expr, this->shape_[i + 2]), symbolic::Ge(i_expr, symbolic::zero())));
10✔
360
            zero_condition = symbolic::
10✔
361
                Or(zero_condition,
10✔
362
                   symbolic::Or(symbolic::Ge(i_expr, this->shape_[i + 2]), symbolic::Lt(i_expr, symbolic::zero())));
10✔
363
        }
10✔
364
        auto& branch = builder.add_if_else(*current_seq, {}, block->debug_info());
5✔
365
        auto& copy_case = builder.add_case(branch, copy_condition, block->debug_info());
5✔
366
        auto& zero_case = builder.add_case(branch, zero_condition, block->debug_info());
5✔
367

368
        // Determine patches subset & tensor type
369
        data_flow::Subset patches_subset;
5✔
370
        patches_subset.push_back(n);
5✔
371
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
5✔
372
        patches_subset.push_back(c);
5✔
373
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
5✔
374
        symbolic::MultiExpression patches_shape;
5✔
375
        patches_shape.push_back(this->shape_[0]);
5✔
376
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
5✔
377
        patches_shape.push_back(this->shape_[1]);
5✔
378
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
5✔
379
        types::Tensor patches_tensor_type(base_type, patches_shape);
5✔
380

381
        // Determine subset for X
382
        data_flow::Subset subset_X;
5✔
383
        subset_X.push_back(n);
5✔
384
        subset_X.push_back(c);
5✔
385
        subset_X.insert(subset_X.end(), is.begin(), is.end());
5✔
386

387
        // Add copy from X to patches
388
        auto& copy_block = builder.add_block(copy_case, {}, block->debug_info());
5✔
389
        {
5✔
390
            auto& X_access = builder.add_access(copy_block, access_X->data(), access_X->debug_info());
5✔
391
            auto& patches_access = builder.add_access(copy_block, patches_container, this->debug_info());
5✔
392
            auto& tasklet =
5✔
393
                builder.add_tasklet(copy_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
394
            builder.add_computational_memlet(
5✔
395
                copy_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
5✔
396
            );
5✔
397
            builder.add_computational_memlet(
5✔
398
                copy_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
5✔
399
            );
5✔
400
        }
5✔
401

402
        // Add zero assignment to patches
403
        auto& zero_block = builder.add_block(zero_case, {}, block->debug_info());
5✔
404
        {
5✔
405
            auto& constant_zero = builder.add_constant(zero_block, "0.0", base_type, this->debug_info());
5✔
406
            auto& patches_access = builder.add_access(zero_block, patches_container, this->debug_info());
5✔
407
            auto& tasklet =
5✔
408
                builder.add_tasklet(zero_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
409
            builder
5✔
410
                .add_computational_memlet(zero_block, constant_zero, tasklet, "_in", {}, base_type, this->debug_info());
5✔
411
            builder.add_computational_memlet(
5✔
412
                zero_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
5✔
413
            );
5✔
414
        }
5✔
415

416
        // Add GEMM node
417
        auto& gemm_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
418
        {
5✔
419
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
5✔
420
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
5✔
421
            auto& W_access = builder.add_access(gemm_block, access_W->data(), access_W->debug_info());
5✔
422
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
5✔
423
            auto& tmp_Y_access_in = builder.add_access(gemm_block, tmp_Y_container, access_Y->debug_info());
5✔
424
            auto& tmp_Y_access_out = builder.add_access(gemm_block, tmp_Y_container, access_Y->debug_info());
5✔
425
            symbolic::Expression gemm_m = this->output_channels_;
5✔
426
            symbolic::Expression gemm_n = this->shape_[0];
5✔
427
            symbolic::Expression gemm_k = this->shape_[1];
5✔
428
            for (size_t i = 0; i < dims; i++) {
15✔
429
                gemm_n = symbolic::mul(gemm_n, out_shape[i]);
10✔
430
                gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
10✔
431
            }
10✔
432
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
5✔
433
                gemm_block,
5✔
434
                this->debug_info(),
5✔
435
                blas::ImplementationType_BLAS,
5✔
436
                precision, // precision
5✔
437
                blas::BLAS_Layout::RowMajor, // layout
5✔
438
                blas::BLAS_Transpose::No, // transA
5✔
439
                blas::BLAS_Transpose::Trans, // transB
5✔
440
                gemm_m, // m
5✔
441
                gemm_n, // n
5✔
442
                gemm_k, // k
5✔
443
                gemm_k, // lda
5✔
444
                gemm_k, // ldb
5✔
445
                gemm_n // ldc
5✔
446
            );
5✔
447
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
5✔
448
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
5✔
449
            builder.add_computational_memlet(
5✔
450
                gemm_block,
5✔
451
                W_access,
5✔
452
                libnode,
5✔
453
                "__A",
5✔
454
                {},
5✔
455
                types::Pointer(types::Scalar(iedge_W->base_type().primitive_type())),
5✔
456
                iedge_W->debug_info()
5✔
457
            );
5✔
458
            builder.add_computational_memlet(
5✔
459
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
5✔
460
            );
5✔
461
            builder.add_computational_memlet(
5✔
462
                gemm_block, tmp_Y_access_in, libnode, "__C", {}, tmp_Y_type, oedge_Y->debug_info()
5✔
463
            );
5✔
464
            builder.add_computational_memlet(
5✔
465
                gemm_block, libnode, "__C", tmp_Y_access_out, {}, tmp_Y_type, oedge_Y->debug_info()
5✔
466
            );
5✔
467
        }
5✔
468

469
        // Add loop over batch size (again)
470
        auto& loop_n_2 = builder.add_map(
5✔
471
            new_sequence,
5✔
472
            n,
5✔
473
            symbolic::Lt(n, this->shape_[0]),
5✔
474
            symbolic::zero(),
5✔
475
            symbolic::add(n, symbolic::one()),
5✔
476
            ScheduleType_Sequential::create(),
5✔
477
            {},
5✔
478
            block->debug_info()
5✔
479
        );
5✔
480
        current_seq = &loop_n_2.root();
5✔
481

482
        // Add loop over output channels
483
        auto l_container = builder.find_new_name("_l");
5✔
484
        builder.add_container(l_container, indvar_type);
5✔
485
        auto l = symbolic::symbol(l_container);
5✔
486
        auto& loop_l = builder.add_map(
5✔
487
            *current_seq,
5✔
488
            l,
5✔
489
            symbolic::Lt(l, this->output_channels_),
5✔
490
            symbolic::zero(),
5✔
491
            symbolic::add(l, symbolic::one()),
5✔
492
            ScheduleType_Sequential::create(),
5✔
493
            {},
5✔
494
            block->debug_info()
5✔
495
        );
5✔
496
        current_seq = &loop_l.root();
5✔
497

498
        // Add loops over output dimensions (again)
499
        for (size_t i = 0; i < dims; i++) {
15✔
500
            auto o_container = builder.find_new_name("_o");
10✔
501
            builder.add_container(o_container, indvar_type);
10✔
502
            auto o = symbolic::symbol(o_container);
10✔
503
            auto& loop_o = builder.add_map(
10✔
504
                *current_seq,
10✔
505
                o,
10✔
506
                symbolic::Lt(o, out_shape[i]),
10✔
507
                symbolic::zero(),
10✔
508
                symbolic::add(o, symbolic::one()),
10✔
509
                ScheduleType_Sequential::create(),
10✔
510
                {},
10✔
511
                block->debug_info()
10✔
512
            );
10✔
513
            current_seq = &loop_o.root();
10✔
514
            os[i] = o;
10✔
515
        }
10✔
516

517
        // Add transposed copy from temporary GEMM output to Y + add bias if available
518
        data_flow::Subset tmp_Y_subset;
5✔
519
        tmp_Y_subset.push_back(l);
5✔
520
        tmp_Y_subset.push_back(n);
5✔
521
        tmp_Y_subset.insert(tmp_Y_subset.end(), os.begin(), os.end());
5✔
522
        symbolic::MultiExpression tmp_Y_shape;
5✔
523
        tmp_Y_shape.push_back(this->output_channels_);
5✔
524
        tmp_Y_shape.push_back(this->shape_[0]);
5✔
525
        tmp_Y_shape.insert(tmp_Y_shape.end(), out_shape.begin(), out_shape.end());
5✔
526
        types::Tensor tmp_Y_tensor_type(tmp_Y_base_type, tmp_Y_shape);
5✔
527
        data_flow::Subset Y_subset;
5✔
528
        Y_subset.push_back(n);
5✔
529
        Y_subset.push_back(l);
5✔
530
        Y_subset.insert(Y_subset.end(), os.begin(), os.end());
5✔
531
        auto& transpose_block = builder.add_block(*current_seq, {}, block->debug_info());
5✔
532
        if (has_bias) {
5✔
533
            auto& tmp_Y_access = builder.add_access(transpose_block, tmp_Y_container, this->debug_info());
×
UNCOV
534
            auto& B_access = builder.add_access(transpose_block, access_B->data(), access_B->debug_info());
×
UNCOV
535
            auto& Y_access = builder.add_access(transpose_block, access_Y->data(), access_Y->debug_info());
×
UNCOV
536
            auto& tasklet = builder.add_tasklet(
×
UNCOV
537
                transpose_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
UNCOV
538
            );
×
UNCOV
539
            builder.add_computational_memlet(
×
UNCOV
540
                transpose_block, tmp_Y_access, tasklet, "_in1", tmp_Y_subset, tmp_Y_tensor_type, this->debug_info()
×
UNCOV
541
            );
×
UNCOV
542
            builder.add_computational_memlet(
×
UNCOV
543
                transpose_block, B_access, tasklet, "_in2", {l}, iedge_B->base_type(), iedge_B->debug_info()
×
UNCOV
544
            );
×
UNCOV
545
            builder.add_computational_memlet(
×
UNCOV
546
                transpose_block, tasklet, "_out", Y_access, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
UNCOV
547
            );
×
548
        } else {
5✔
549
            auto& tmp_Y_access = builder.add_access(transpose_block, tmp_Y_container, this->debug_info());
5✔
550
            auto& Y_access = builder.add_access(transpose_block, access_Y->data(), access_Y->debug_info());
5✔
551
            auto& tasklet =
5✔
552
                builder
5✔
553
                    .add_tasklet(transpose_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
554
            builder.add_computational_memlet(
5✔
555
                transpose_block, tmp_Y_access, tasklet, "_in", tmp_Y_subset, tmp_Y_tensor_type, this->debug_info()
5✔
556
            );
5✔
557
            builder.add_computational_memlet(
5✔
558
                transpose_block, tasklet, "_out", Y_access, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
5✔
559
            );
5✔
560
        }
5✔
561

562
        // Add free for patches container
563
        auto& patches_free_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
564
        {
5✔
565
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
566
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
567
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
5✔
568
            builder.add_computational_memlet(
5✔
569
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
5✔
570
            );
5✔
571
            builder.add_computational_memlet(
5✔
572
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
5✔
573
            );
5✔
574
        }
5✔
575

576
        // Add free for temporary GEMM output
577
        auto& tmp_Y_free_block = builder.add_block(new_sequence, {}, block->debug_info());
5✔
578
        {
5✔
579
            auto& tmp_Y_access_in = builder.add_access(tmp_Y_free_block, tmp_Y_container, this->debug_info());
5✔
580
            auto& tmp_Y_access_out = builder.add_access(tmp_Y_free_block, tmp_Y_container, this->debug_info());
5✔
581
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(tmp_Y_free_block, this->debug_info());
5✔
582
            builder.add_computational_memlet(
5✔
583
                tmp_Y_free_block, tmp_Y_access_in, libnode, "_ptr", {}, tmp_Y_type, this->debug_info()
5✔
584
            );
5✔
585
            builder.add_computational_memlet(
5✔
586
                tmp_Y_free_block, libnode, "_ptr", tmp_Y_access_out, {}, tmp_Y_type, this->debug_info()
5✔
587
            );
5✔
588
        }
5✔
589

590
        /* ===== No groups ====================================================================== */
591

592
    } else {
5✔
593
        /* ===== Groups ========================================================================= */
594

595
        auto in_channels = symbolic::div(this->shape_[1], this->group_);
×
596
        auto out_channels = symbolic::div(this->output_channels_, this->group_);
×
597

598
        // Add loop over batch size
UNCOV
599
        auto n_container = builder.find_new_name("_n");
×
600
        builder.add_container(n_container, indvar_type);
×
601
        auto n = symbolic::symbol(n_container);
×
602
        auto& loop_n = builder.add_map(
×
603
            new_sequence,
×
604
            n,
×
605
            symbolic::Lt(n, this->shape_[0]),
×
606
            symbolic::zero(),
×
607
            symbolic::add(n, symbolic::one()),
×
608
            ScheduleType_Sequential::create(),
×
609
            {},
×
610
            block->debug_info()
×
611
        );
×
612

613
        // Add loop over groups
UNCOV
614
        auto g_container = builder.find_new_name("_g");
×
615
        builder.add_container(g_container, indvar_type);
×
616
        auto g = symbolic::symbol(g_container);
×
617
        auto& loop_g = builder.add_map(
×
618
            loop_n.root(),
×
619
            g,
×
620
            symbolic::Lt(g, this->group_),
×
621
            symbolic::zero(),
×
622
            symbolic::add(g, symbolic::one()),
×
623
            ScheduleType_Sequential::create(),
×
624
            {},
×
625
            block->debug_info()
×
626
        );
×
627

628
        // Add patches container with malloc
629
        symbolic::Expression patches_size = in_channels;
×
630
        for (size_t i = 0; i < dims; i++) {
×
631
            patches_size = symbolic::mul(patches_size, symbolic::mul(this->kernel_shape_[i], out_shape[i]));
×
UNCOV
632
        }
×
UNCOV
633
        types::Pointer patches_type(base_type);
×
634
        auto patches_container = builder.find_new_name("_patches");
×
635
        builder.add_container(patches_container, patches_type);
×
636
        auto& patches_malloc_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
637
        {
×
638
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
×
639
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
×
640
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
×
641
            );
×
642
            builder.add_computational_memlet(
×
643
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
×
644
            );
×
645
        }
×
646

647
        // Add loops over output dimensions
UNCOV
648
        structured_control_flow::Sequence* current_seq = &loop_g.root();
×
649
        symbolic::SymbolVec os;
×
650
        os.reserve(dims);
×
651
        for (size_t i = 0; i < dims; i++) {
×
652
            auto o_container = builder.find_new_name("_o");
×
653
            builder.add_container(o_container, indvar_type);
×
654
            auto o = symbolic::symbol(o_container);
×
655
            os.push_back(o);
×
656
            auto& loop_o = builder.add_map(
×
657
                *current_seq,
×
658
                o,
×
659
                symbolic::Lt(o, out_shape[i]),
×
660
                symbolic::zero(),
×
661
                symbolic::add(o, symbolic::one()),
×
662
                ScheduleType_Sequential::create(),
×
663
                {},
×
664
                block->debug_info()
×
665
            );
×
666
            current_seq = &loop_o.root();
×
667
        }
×
668

669
        // Add loop over channels
UNCOV
670
        auto c_container = builder.find_new_name("_c");
×
671
        builder.add_container(c_container, indvar_type);
×
672
        auto c = symbolic::symbol(c_container);
×
673
        auto& loop_c = builder.add_map(
×
674
            *current_seq,
×
675
            c,
×
676
            symbolic::Lt(c, in_channels),
×
677
            symbolic::zero(),
×
678
            symbolic::add(c, symbolic::one()),
×
679
            ScheduleType_Sequential::create(),
×
680
            {},
×
681
            block->debug_info()
×
682
        );
×
683
        current_seq = &loop_c.root();
×
684

685
        // Add loops over kernel shape
UNCOV
686
        symbolic::SymbolVec ks;
×
687
        ks.reserve(dims);
×
688
        for (size_t i = 0; i < dims; i++) {
×
689
            auto k_container = builder.find_new_name("_k");
×
690
            builder.add_container(k_container, indvar_type);
×
691
            auto k = symbolic::symbol(k_container);
×
692
            ks.push_back(k);
×
693
            auto& loop_k = builder.add_map(
×
694
                *current_seq,
×
695
                k,
×
696
                symbolic::Lt(k, this->kernel_shape_[i]),
×
697
                symbolic::zero(),
×
698
                symbolic::add(k, symbolic::one()),
×
699
                ScheduleType_Sequential::create(),
×
700
                {},
×
701
                block->debug_info()
×
702
            );
×
703
            current_seq = &loop_k.root();
×
704
        }
×
705

706
        // Add if/else to stay in bounds for copying
707
        symbolic::MultiExpression is;
×
708
        is.reserve(dims);
×
709
        symbolic::Condition copy_condition = symbolic::__true__();
×
710
        symbolic::Condition zero_condition = symbolic::__false__();
×
711
        for (size_t i = 0; i < dims; i++) {
×
UNCOV
712
            auto i_expr = symbolic::
×
UNCOV
713
                add(symbolic::sub(symbolic::mul(os[i], this->strides_[i]), this->pads_[i]),
×
714
                    symbolic::mul(ks[i], this->dilations_[i]));
×
715
            is.push_back(i_expr);
×
716
            copy_condition = symbolic::
×
717
                And(copy_condition,
×
718
                    symbolic::And(symbolic::Lt(i_expr, this->shape_[i + 2]), symbolic::Ge(i_expr, symbolic::zero())));
×
719
            zero_condition = symbolic::
×
720
                Or(zero_condition,
×
721
                   symbolic::Or(symbolic::Ge(i_expr, this->shape_[i + 2]), symbolic::Lt(i_expr, symbolic::zero())));
×
722
        }
×
UNCOV
723
        auto& branch = builder.add_if_else(*current_seq, {}, block->debug_info());
×
UNCOV
724
        auto& copy_case = builder.add_case(branch, copy_condition, block->debug_info());
×
725
        auto& zero_case = builder.add_case(branch, zero_condition, block->debug_info());
×
726

727
        // Determine patches subset & tensor type
728
        data_flow::Subset patches_subset;
×
UNCOV
729
        patches_subset.push_back(c);
×
UNCOV
730
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
×
731
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
×
732
        symbolic::MultiExpression patches_shape;
×
733
        patches_shape.push_back(in_channels);
×
734
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
×
735
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
×
736
        types::Tensor patches_tensor_type(base_type, patches_shape);
×
737

738
        // Determine subset for X
739
        data_flow::Subset subset_X;
×
740
        subset_X.push_back(n);
×
741
        subset_X.push_back(symbolic::add(symbolic::mul(in_channels, g), c));
×
742
        subset_X.insert(subset_X.end(), is.begin(), is.end());
×
743

744
        // Add copy from X to patches
UNCOV
745
        auto& copy_block = builder.add_block(copy_case, {}, block->debug_info());
×
746
        {
×
747
            auto& X_access = builder.add_access(copy_block, access_X->data(), access_X->debug_info());
×
748
            auto& patches_access = builder.add_access(copy_block, patches_container, this->debug_info());
×
749
            auto& tasklet =
×
750
                builder.add_tasklet(copy_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
×
751
            builder.add_computational_memlet(
×
752
                copy_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
×
753
            );
×
754
            builder.add_computational_memlet(
×
755
                copy_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
×
756
            );
×
757
        }
×
758

759
        // Add zero assignment to patches
UNCOV
760
        auto& zero_block = builder.add_block(zero_case, {}, block->debug_info());
×
UNCOV
761
        {
×
762
            auto& constant_zero = builder.add_constant(zero_block, "0.0", base_type, this->debug_info());
×
763
            auto& patches_access = builder.add_access(zero_block, patches_container, this->debug_info());
×
764
            auto& tasklet =
×
765
                builder.add_tasklet(zero_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
×
766
            builder
×
767
                .add_computational_memlet(zero_block, constant_zero, tasklet, "_in", {}, base_type, this->debug_info());
×
768
            builder.add_computational_memlet(
×
769
                zero_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
×
770
            );
×
771
        }
×
772

773
        // Add reference to W
774
        auto ref_W_container = builder.find_new_name("_ref_W");
×
775
        types::Scalar ref_W_base_type(builder.subject().type(access_W->data()).primitive_type());
×
UNCOV
776
        types::Pointer ref_W_type(ref_W_base_type);
×
UNCOV
777
        builder.add_container(ref_W_container, ref_W_type);
×
778
        auto ref_W_subset = symbolic::mul(symbolic::mul(out_channels, g), in_channels);
×
779
        for (size_t i = 0; i < dims; i++) {
×
780
            ref_W_subset = symbolic::mul(ref_W_subset, this->kernel_shape_[i]);
×
781
        }
×
782
        auto& ref_W_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
783
        {
×
784
            auto& W_access = builder.add_access(ref_W_block, access_W->data(), access_W->debug_info());
×
785
            auto& ref_W_access = builder.add_access(ref_W_block, ref_W_container, access_W->debug_info());
×
786
            builder.add_reference_memlet(ref_W_block, W_access, ref_W_access, {ref_W_subset}, ref_W_type);
×
787
        }
×
788

789
        // Add reference to Y
790
        auto ref_Y_container = builder.find_new_name("_ref_Y");
×
791
        types::Scalar ref_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
×
792
        types::Pointer ref_Y_type(ref_Y_base_type);
×
793
        builder.add_container(ref_Y_container, ref_Y_type);
×
794
        auto ref_Y_subset = symbolic::add(symbolic::mul(this->output_channels_, n), symbolic::mul(out_channels, g));
×
795
        for (size_t i = 0; i < dims; i++) {
×
796
            ref_Y_subset = symbolic::mul(ref_Y_subset, out_shape[i]);
×
797
        }
×
798
        auto& ref_Y_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
799
        {
×
800
            auto& Y_access = builder.add_access(ref_Y_block, access_Y->data(), access_Y->debug_info());
×
801
            auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, access_Y->debug_info());
×
802
            builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_subset}, ref_Y_type);
×
803
        }
×
804

805
        // Add GEMM node
806
        auto& gemm_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
807
        {
×
808
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
×
809
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
×
810
            auto& ref_W_access = builder.add_access(gemm_block, ref_W_container, access_W->debug_info());
×
811
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
×
812
            auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
813
            auto& ref_Y_access_out = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
814
            symbolic::Expression gemm_m = out_channels;
×
815
            symbolic::Expression gemm_n = symbolic::one();
×
816
            symbolic::Expression gemm_k = in_channels;
×
817
            for (size_t i = 0; i < dims; i++) {
×
818
                gemm_n = symbolic::mul(gemm_n, out_shape[i]);
×
819
                gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
×
820
            }
×
821
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
×
UNCOV
822
                gemm_block,
×
UNCOV
823
                this->debug_info(),
×
824
                blas::ImplementationType_BLAS,
×
UNCOV
825
                precision, // precision
×
826
                blas::BLAS_Layout::RowMajor, // layout
×
827
                blas::BLAS_Transpose::No, // transA
×
828
                blas::BLAS_Transpose::No, // transB
×
829
                gemm_m, // m
×
830
                gemm_n, // n
×
831
                gemm_k, // k
×
832
                gemm_k, // lda
×
833
                gemm_n, // ldb
×
834
                gemm_n // ldc
×
835
            );
×
836
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
×
837
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
×
838
            builder
×
839
                .add_computational_memlet(gemm_block, ref_W_access, libnode, "__A", {}, ref_W_type, iedge_W->debug_info());
×
UNCOV
840
            builder.add_computational_memlet(
×
UNCOV
841
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
×
842
            );
×
843
            builder.add_computational_memlet(
×
844
                gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, oedge_Y->debug_info()
×
845
            );
×
846
            builder.add_computational_memlet(
×
847
                gemm_block, libnode, "__C", ref_Y_access_out, {}, ref_Y_type, oedge_Y->debug_info()
×
848
            );
×
849
        }
×
850

851
        // Add bias if available
852
        if (has_bias) {
×
853
            // Add loop over output channels
854
            auto l_container = builder.find_new_name("_l");
×
855
            builder.add_container(l_container, indvar_type);
×
856
            auto l = symbolic::symbol(l_container);
×
857
            auto& loop_l = builder.add_map(
×
858
                loop_g.root(),
×
UNCOV
859
                l,
×
UNCOV
860
                symbolic::Lt(l, out_channels),
×
861
                symbolic::zero(),
×
862
                symbolic::add(l, symbolic::one()),
×
863
                ScheduleType_Sequential::create(),
×
864
                {},
×
865
                block->debug_info()
×
866
            );
×
867
            current_seq = &loop_l.root();
×
868

869
            // Add loops over output dimensions (again)
870
            for (size_t i = 0; i < dims; i++) {
×
871
                auto o_container = builder.find_new_name("_o");
×
872
                builder.add_container(o_container, indvar_type);
×
873
                auto o = symbolic::symbol(o_container);
×
874
                auto& loop_o = builder.add_map(
×
875
                    *current_seq,
×
876
                    o,
×
877
                    symbolic::Lt(o, out_shape[i]),
×
878
                    symbolic::zero(),
×
879
                    symbolic::add(o, symbolic::one()),
×
880
                    ScheduleType_Sequential::create(),
×
881
                    {},
×
882
                    block->debug_info()
×
883
                );
×
884
                current_seq = &loop_o.root();
×
UNCOV
885
                os[i] = o;
×
UNCOV
886
            }
×
887

888
            // Add bias to Y
889
            data_flow::Subset Y_subset;
×
890
            Y_subset.push_back(n);
×
891
            Y_subset.push_back(symbolic::add(symbolic::mul(out_channels, g), l));
×
892
            Y_subset.insert(Y_subset.end(), os.begin(), os.end());
×
893
            auto B_subset = symbolic::add(symbolic::mul(out_channels, g), l);
×
894
            auto& bias_block = builder.add_block(*current_seq, {}, block->debug_info());
×
895
            {
×
896
                auto& B_access = builder.add_access(bias_block, access_B->data(), access_B->debug_info());
×
897
                auto& Y_access_in = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
898
                auto& Y_access_out = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
UNCOV
899
                auto& tasklet = builder.add_tasklet(
×
UNCOV
900
                    bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
901
                );
×
UNCOV
902
                builder.add_computational_memlet(
×
UNCOV
903
                    bias_block, Y_access_in, tasklet, "_in1", Y_subset, oedge_Y->base_type(), this->debug_info()
×
UNCOV
904
                );
×
UNCOV
905
                builder.add_computational_memlet(
×
UNCOV
906
                    bias_block, B_access, tasklet, "_in2", {B_subset}, iedge_B->base_type(), iedge_B->debug_info()
×
907
                );
×
908
                builder.add_computational_memlet(
×
UNCOV
909
                    bias_block, tasklet, "_out", Y_access_out, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
UNCOV
910
                );
×
UNCOV
911
            }
×
UNCOV
912
        }
×
913

914
        // Add free for patches container
UNCOV
915
        auto& patches_free_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
UNCOV
916
        {
×
UNCOV
917
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
UNCOV
918
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
UNCOV
919
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
×
UNCOV
920
            builder.add_computational_memlet(
×
UNCOV
921
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
×
922
            );
×
923
            builder.add_computational_memlet(
×
UNCOV
924
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
×
925
            );
×
926
        }
×
927

928
        /* ===== Groups ========================================================================= */
929
    }
×
930

931
    // Clean up the original block
932
    builder.remove_memlet(*block, *iedge_X);
5✔
933
    builder.remove_memlet(*block, *iedge_W);
5✔
934
    if (has_bias) {
5✔
935
        builder.remove_memlet(*block, *iedge_B);
×
936
    }
×
937
    builder.remove_memlet(*block, *oedge_Y);
5✔
938
    builder.remove_node(*block, *access_X);
5✔
939
    builder.remove_node(*block, *access_W);
5✔
940
    if (has_bias) {
5✔
941
        builder.remove_node(*block, *access_B);
×
942
    }
×
943
    builder.remove_node(*block, *access_Y);
5✔
944
    builder.remove_node(*block, *this);
5✔
945
    builder.remove_child(*block_parent, block_index + 1);
5✔
946

947
    return true;
5✔
948
}
5✔
949

950
symbolic::SymbolSet ConvNode::symbols() const {
    symbolic::SymbolSet syms;

    // Inserts every free symbol (atom) of a single expression into `syms`.
    // Factoring this out removes seven copies of the same loop.
    auto collect = [&syms](const symbolic::Expression& expr) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    };

    // Every symbolic attribute of the node may contribute free symbols.
    for (auto& expr : shape_) {
        collect(expr);
    }
    for (auto& expr : kernel_shape_) {
        collect(expr);
    }
    for (auto& expr : strides_) {
        collect(expr);
    }
    for (auto& expr : pads_) {
        collect(expr);
    }
    for (auto& expr : dilations_) {
        collect(expr);
    }
    collect(output_channels_);
    collect(group_);

    return syms;
}
×
987

UNCOV
988
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Substitutes old_expression -> new_expression in-place in every element
    // of a vector attribute. Factoring this out removes five copies of the
    // same loop.
    auto subs_all = [&](std::vector<symbolic::Expression>& exprs) {
        for (auto& expr : exprs) {
            expr = symbolic::subs(expr, old_expression, new_expression);
        }
    };

    subs_all(shape_);
    subs_all(kernel_shape_);
    subs_all(strides_);
    subs_all(pads_);
    subs_all(dilations_);

    // Scalar attributes are substituted directly.
    output_channels_ = symbolic::subs(output_channels_, old_expression, new_expression);
    group_ = symbolic::subs(group_, old_expression, new_expression);
}
×
1007

1008
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
1009
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
1✔
1010
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
1✔
1011
        element_id,
1✔
1012
        this->debug_info(),
1✔
1013
        vertex,
1✔
1014
        parent,
1✔
1015
        shape_,
1✔
1016
        kernel_shape_,
1✔
1017
        strides_,
1✔
1018
        pads_,
1✔
1019
        dilations_,
1✔
1020
        output_channels_,
1✔
1021
        group_
1✔
1022
    ));
1✔
1023
}
1✔
1024

1025
std::string ConvNode::toStr() const {
×
1026
    std::stringstream result;
×
1027
    result << "Conv(shape=[";
×
1028
    for (size_t i = 0; i < shape_.size(); ++i) {
×
1029
        if (i > 0) {
×
1030
            result << ", ";
×
1031
        }
×
1032
        result << shape_[i]->__str__();
×
1033
    }
×
1034
    result << "], kernel_shape=[";
×
1035
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
1036
        if (i > 0) {
×
1037
            result << ", ";
×
UNCOV
1038
        }
×
1039
        result << kernel_shape_[i]->__str__();
×
1040
    }
×
1041
    result << "], strides=[";
×
UNCOV
1042
    for (size_t i = 0; i < strides_.size(); ++i) {
×
1043
        if (i > 0) {
×
UNCOV
1044
            result << ", ";
×
1045
        }
×
UNCOV
1046
        result << strides_[i]->__str__();
×
1047
    }
×
1048
    result << "], pads=[";
×
1049
    for (size_t i = 0; i < pads_.size(); ++i) {
×
1050
        if (i > 0) {
×
UNCOV
1051
            result << ", ";
×
1052
        }
×
1053
        result << pads_[i]->__str__();
×
1054
    }
×
1055
    result << "], dilations=[";
×
UNCOV
1056
    for (size_t i = 0; i < dilations_.size(); ++i) {
×
1057
        if (i > 0) {
×
1058
            result << ", ";
×
1059
        }
×
1060
        result << dilations_[i]->__str__();
×
UNCOV
1061
    }
×
1062
    result << "], output_channels=" + output_channels_->__str__();
×
1063
    result << ", group=" + group_->__str__() + ")";
×
1064
    return result.str();
×
1065
}
×
1066

1067
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    // Serializes a ConvNode's symbolic configuration to JSON; the inverse of
    // ConvNodeSerializer::deserialize.
    const ConvNode& conv_node = static_cast<const ConvNode&>(library_node);
    nlohmann::json j;

    j["code"] = conv_node.code().value();

    serializer::JSONSerializer serializer;

    // Writes a vector of symbolic expressions as a JSON array under `key`.
    // Factoring this out removes five copies of the same loop.
    auto put_array = [&](const char* key, const auto& exprs) {
        j[key] = nlohmann::json::array();
        for (auto& expr : exprs) {
            j[key].push_back(serializer.expression(expr));
        }
    };

    put_array("shape", conv_node.shape());
    put_array("kernel_shape", conv_node.kernel_shape());
    put_array("strides", conv_node.strides());
    put_array("pads", conv_node.pads());
    put_array("dilations", conv_node.dilations());

    j["output_channels"] = serializer.expression(conv_node.output_channels());
    j["group"] = serializer.expression(conv_node.group());

    return j;
}
×
1105

1106
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Reconstructs a ConvNode from the JSON produced by serialize() and adds
    // it to `parent` via the builder. Only "kernel_shape" (plus the generic
    // element_id/code/debug_info fields) is mandatory; all other attributes
    // fall back to empty vectors / 1.
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parses an optional JSON array of expression strings; absent key yields
    // an empty vector. Factoring this out removes five copies of the loop.
    auto parse_list = [&j](const char* key) {
        std::vector<symbolic::Expression> exprs;
        if (j.contains(key)) {
            for (const auto& entry : j[key]) {
                exprs.push_back(symbolic::parse(entry.get<std::string>()));
            }
        }
        return exprs;
    };

    // Parses an optional scalar expression; absent key defaults to 1.
    auto parse_scalar_or_one = [&j](const char* key) -> symbolic::Expression {
        if (j.contains(key)) {
            return symbolic::parse(j[key].get<std::string>());
        }
        return symbolic::one();
    };

    std::vector<symbolic::Expression> shape = parse_list("shape");
    std::vector<symbolic::Expression> kernel_shape = parse_list("kernel_shape"); // presence asserted above
    std::vector<symbolic::Expression> strides = parse_list("strides");
    std::vector<symbolic::Expression> pads = parse_list("pads");
    std::vector<symbolic::Expression> dilations = parse_list("dilations");

    symbolic::Expression output_channels = parse_scalar_or_one("output_channels");
    symbolic::Expression group = parse_scalar_or_one("group");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
×
1163

1164
} // namespace tensor
1165
} // namespace math
1166
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc