• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 24159142213

08 Apr 2026 09:16PM UTC coverage: 64.86% (-0.1%) from 64.986%
24159142213

Pull #663

github

web-flow
Merge 906a00e14 into 223814883
Pull Request #663: expands conv node patch-wise

85 of 210 new or added lines in 1 file covered. (40.48%)

8 existing lines in 1 file now uncovered.

29143 of 44932 relevant lines covered (64.86%)

602.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

37.37
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4
#include <sstream>
5

6
#include "sdfg/analysis/analysis.h"
7
#include "sdfg/builder/structured_sdfg_builder.h"
8
#include "sdfg/data_flow/access_node.h"
9
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
10
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
11
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
12
#include "sdfg/data_flow/library_nodes/stdlib/memset.h"
13
#include "sdfg/data_flow/memlet.h"
14
#include "sdfg/data_flow/tasklet.h"
15
#include "sdfg/exceptions.h"
16
#include "sdfg/structured_control_flow/block.h"
17
#include "sdfg/structured_control_flow/map.h"
18
#include "sdfg/structured_control_flow/sequence.h"
19
#include "sdfg/symbolic/symbolic.h"
20
#include "sdfg/types/pointer.h"
21
#include "sdfg/types/scalar.h"
22
#include "sdfg/types/tensor.h"
23
#include "sdfg/types/type.h"
24

25
#include "sdfg/analysis/scope_analysis.h"
26
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
27
#include "symengine/integer.h"
28
#include "symengine/symengine_rcp.h"
29

30
namespace sdfg {
31
namespace math {
32
namespace tensor {
33

34
ConvNode::ConvNode(
35
    size_t element_id,
36
    const DebugInfo& debug_info,
37
    const graph::Vertex vertex,
38
    data_flow::DataFlowGraph& parent,
39
    const std::vector<symbolic::Expression>& shape,
40
    const std::vector<symbolic::Expression>& kernel_shape,
41
    const std::vector<symbolic::Expression>& strides,
42
    const std::vector<symbolic::Expression>& pads,
43
    const std::vector<symbolic::Expression>& dilations,
44
    symbolic::Expression output_channels,
45
    symbolic::Expression group
46
)
47
    : TensorNode(
22✔
48
          element_id,
22✔
49
          debug_info,
22✔
50
          vertex,
22✔
51
          parent,
22✔
52
          LibraryNodeType_Conv,
22✔
53
          {"Y"},
22✔
54
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
22✔
55
          data_flow::ImplementationType_NONE
22✔
56
      ),
22✔
57
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
22✔
58
      output_channels_(output_channels), group_(group) {}
22✔
59

60
void ConvNode::validate(const Function& function) const {
17✔
61
    TensorNode::validate(function);
17✔
62

63
    auto& graph = this->get_parent();
17✔
64

65
    // Custom validation for ConvNode that handles optional bias input
66
    // We expect X, W as required inputs and optionally B (bias)
67

68
    // Collect all input edges by connector name
69
    std::map<std::string, const data_flow::Memlet*> input_edges;
17✔
70
    for (auto& iedge : graph.in_edges(*this)) {
34✔
71
        input_edges[iedge.dst_conn()] = &iedge;
34✔
72
    }
34✔
73

74
    // Check that required inputs X and W are present
75
    if (input_edges.find("X") == input_edges.end()) {
17✔
76
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
77
    }
×
78
    if (input_edges.find("W") == input_edges.end()) {
17✔
79
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
80
    }
×
81

82
    // Validate that parameters are not empty
83
    if (shape_.empty()) {
17✔
84
        throw InvalidSDFGException("ConvNode shape cannot be empty");
×
85
    }
×
86
    if (kernel_shape_.empty()) {
17✔
87
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
88
    }
×
89
    if (strides_.empty()) {
17✔
90
        throw InvalidSDFGException("ConvNode strides cannot be empty");
×
91
    }
×
92
    if (pads_.empty()) {
17✔
93
        throw InvalidSDFGException("ConvNode pads cannot be empty");
×
94
    }
×
95
    if (dilations_.empty()) {
17✔
96
        throw InvalidSDFGException("ConvNode dilations cannot be empty");
×
97
    }
×
98

99
    // Validate consistent dimensions
100
    size_t spatial_dims = kernel_shape_.size();
17✔
101

102
    if (shape_.size() != spatial_dims + 2) {
17✔
103
        throw InvalidSDFGException("ConvNode shape must match kernel spatial dimensions + 2");
×
104
    }
×
105

106
    if (strides_.size() != spatial_dims) {
17✔
107
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
108
    }
1✔
109

110
    if (pads_.size() != 2 * spatial_dims) {
16✔
111
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
112
    }
1✔
113

114
    if (dilations_.size() != spatial_dims) {
15✔
115
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
116
    }
×
117

118
    // Validate groups
119
    if (SymEngine::is_a<SymEngine::Integer>(*this->group_)) {
15✔
120
        auto group_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->group_)->as_int();
15✔
121
        if (SymEngine::is_a<SymEngine::Integer>(*this->shape_[1])) {
15✔
122
            auto input_channels_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->shape_[1])->as_int();
15✔
123
            if (input_channels_int % group_int != 0) {
15✔
124
                throw InvalidSDFGException("ConvNode input channels must be divisible by groups");
×
125
            }
×
126
        }
15✔
127
        if (SymEngine::is_a<SymEngine::Integer>(*this->output_channels_)) {
15✔
128
            auto output_channels_int =
15✔
129
                SymEngine::rcp_static_cast<const SymEngine::Integer>(this->output_channels_)->as_int();
15✔
130
            if (output_channels_int % group_int != 0) {
15✔
131
                throw InvalidSDFGException("ConvNode output channels must be divisible by groups");
×
132
            }
×
133
        }
15✔
134
    }
15✔
135
}
15✔
136

137
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
138
    // Validate nodes are standalone in the data flow graph
139
    auto& dfg = this->get_parent();
5✔
140
    if ((dfg.nodes().size() != 4 || dfg.edges().size() != 3) && (dfg.nodes().size() != 5 || dfg.edges().size() != 4)) {
5✔
141
        return false;
×
142
    }
×
143

144
    // Get edges
145
    auto iedges = dfg.in_edges_by_connector(*this);
5✔
146
    auto oedges = dfg.out_edges_by_connector(*this);
5✔
147
    if (iedges.size() != 3 || oedges.size() != 1) {
5✔
148
        return false;
×
149
    }
×
150
    auto* iedge_X = iedges.at(0);
5✔
151
    auto* iedge_W = iedges.at(1);
5✔
152
    auto* iedge_B = iedges.at(2);
5✔
153
    auto* oedge_Y = oedges.at(0);
5✔
154
    if (!iedge_X || !iedge_W || !oedge_Y) {
5✔
155
        return false;
×
156
    }
×
157
    bool has_bias = iedge_B != nullptr;
5✔
158

159
    // Get access nodes
160
    auto* access_X = dynamic_cast<data_flow::AccessNode*>(&iedge_X->src());
5✔
161
    auto* access_W = dynamic_cast<data_flow::AccessNode*>(&iedge_W->src());
5✔
162
    auto* access_B = (has_bias ? dynamic_cast<data_flow::AccessNode*>(&iedge_B->src()) : nullptr);
5✔
163
    auto* access_Y = dynamic_cast<data_flow::AccessNode*>(&oedge_Y->dst());
5✔
164
    if (!access_X || !access_W || (has_bias && !access_B) || !access_Y) {
5✔
165
        return false;
×
166
    }
×
167

168
    // Get block & its parent
169
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
5✔
170
    if (!block) {
5✔
171
        return false;
×
172
    }
×
173
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
174
    auto* block_parent = dynamic_cast<structured_control_flow::Sequence*>(scope_analysis.parent_scope(block));
5✔
175
    if (!block_parent) {
5✔
176
        return false;
×
177
    }
×
178
    size_t block_index = block_parent->index(*block);
5✔
179
    if (block_index >= block_parent->size()) {
5✔
180
        return false;
×
181
    }
×
182

183
    // Determine BLAS precision
184
    blas::BLAS_Precision precision;
5✔
185
    types::Scalar base_type(this->primitive_type(dfg));
5✔
186
    switch (base_type.primitive_type()) {
5✔
187
        case types::PrimitiveType::Half:
×
188
            precision = blas::BLAS_Precision::h;
×
189
            break;
×
190
        case types::PrimitiveType::Float:
5✔
191
            precision = blas::BLAS_Precision::s;
5✔
192
            break;
5✔
193
        case types::PrimitiveType::Double:
×
194
            precision = blas::BLAS_Precision::d;
×
195
            break;
×
196
        default:
×
197
            return false;
×
198
    }
5✔
199

200
    // Create new sequence for expansion
201
    auto& new_sequence = builder.add_sequence_before(
5✔
202
        *block_parent, *block, block_parent->at(block_index).second.assignments(), block->debug_info()
5✔
203
    );
5✔
204

205
    // Dimensions, i.e., 1D, 2D, 3D, ...
206
    size_t dims = this->kernel_shape_.size();
5✔
207
    symbolic::MultiExpression out_shape;
5✔
208
    out_shape.reserve(dims);
5✔
209
    // out_shape[i] = (shape[i + 2] + pads[i] + pads[dims + i] - dilations[i] * (kernel_shape[i] - 1) - 1)
210
    //                 / strides[i] + 1
211
    for (size_t i = 0; i < dims; i++) {
15✔
212
        out_shape.push_back(symbolic::add(
10✔
213
            symbolic::div(
10✔
214
                symbolic::sub(
10✔
215
                    symbolic::
10✔
216
                        sub(symbolic::add(this->shape_[i + 2], symbolic::add(this->pads_[i], this->pads_[dims + i])),
10✔
217
                            symbolic::mul(this->dilations_[i], symbolic::sub(this->kernel_shape_[i], symbolic::one()))),
10✔
218
                    symbolic::one()
10✔
219
                ),
10✔
220
                this->strides_[i]
10✔
221
            ),
10✔
222
            symbolic::one()
10✔
223
        ));
10✔
224
    }
10✔
225
    types::Scalar indvar_type(types::PrimitiveType::Int64);
5✔
226

227
    // Compute spatial dimensions: K = C_in * prod(kernel_shape), spatial = prod(out_shape)
228
    // Patches layout: [K, spatial] = [C_in, kernel..., out_shape...]
229
    // GEMM: Y[F, spatial] = W[F, K] × patches[K, spatial]  (NoTrans × NoTrans)
230
    // Output Y is [N, F, out_shape...] — GEMM writes directly per batch via pointer reference
231

232
    if (symbolic::eq(this->group_, symbolic::one())) {
5✔
233
        /* ===== No groups ====================================================================== */
234

235
        // Compute K = C_in * prod(kernel_shape) and spatial = prod(out_shape)
236
        symbolic::Expression gemm_k = this->shape_[1]; // C_in
5✔
237
        for (size_t i = 0; i < dims; i++) {
15✔
238
            gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
10✔
239
        }
10✔
240
        symbolic::Expression spatial = symbolic::one();
5✔
241
        for (size_t i = 0; i < dims; i++) {
15✔
242
            spatial = symbolic::mul(spatial, out_shape[i]);
10✔
243
        }
10✔
244

245
        // Patches buffer: K × spatial (per-batch, reused across batches)
246
        symbolic::Expression patches_size = symbolic::mul(gemm_k, spatial);
5✔
247
        types::Pointer patches_type(base_type);
5✔
248
        auto patches_container = builder.find_new_name("_patches");
5✔
249
        builder.add_container(patches_container, patches_type);
5✔
250

251
        // Batch loop
252
        auto n_container = builder.find_new_name("_n");
5✔
253
        builder.add_container(n_container, indvar_type);
5✔
254
        auto n = symbolic::symbol(n_container);
5✔
255
        auto& loop_n = builder.add_map(
5✔
256
            new_sequence,
5✔
257
            n,
5✔
258
            symbolic::Lt(n, this->shape_[0]),
5✔
259
            symbolic::zero(),
5✔
260
            symbolic::add(n, symbolic::one()),
5✔
261
            ScheduleType_Sequential::create(),
5✔
262
            {},
5✔
263
            block->debug_info()
5✔
264
        );
5✔
265

266
        // Malloc patches
267
        auto& patches_malloc_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
268
        {
5✔
269
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
5✔
270
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
5✔
271
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
5✔
272
            );
5✔
273
            builder.add_computational_memlet(
5✔
274
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
5✔
275
            );
5✔
276
        }
5✔
277

278
        // Memset patches to zero (inside batch loop, per-batch)
279
        auto& patches_memset_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
280
        {
5✔
281
            auto& patches_access = builder.add_access(patches_memset_block, patches_container, this->debug_info());
5✔
282
            auto& libnode = builder.add_library_node<stdlib::MemsetNode>(
5✔
283
                patches_memset_block,
5✔
284
                this->debug_info(),
5✔
285
                symbolic::zero(),
5✔
286
                symbolic::mul(patches_size, symbolic::size_of_type(base_type))
5✔
287
            );
5✔
288
            builder.add_computational_memlet(
5✔
289
                patches_memset_block, libnode, "_ptr", patches_access, {}, patches_type, this->debug_info()
5✔
290
            );
5✔
291
        }
5✔
292

293
        // Im2col: nested loops over channels, kernel, output positions
294
        // Loop order: c, k[0..dims-1], o[0..dims-1]
295
        // Patches layout: [C_in, kernel..., out_shape...] — row-major gives [K, spatial]
296
        structured_control_flow::Sequence* current_seq = &loop_n.root();
5✔
297

298
        // Channel loop
299
        auto c_container = builder.find_new_name("_c");
5✔
300
        builder.add_container(c_container, indvar_type);
5✔
301
        auto c = symbolic::symbol(c_container);
5✔
302
        auto& loop_c = builder.add_map(
5✔
303
            *current_seq,
5✔
304
            c,
5✔
305
            symbolic::Lt(c, this->shape_[1]),
5✔
306
            symbolic::zero(),
5✔
307
            symbolic::add(c, symbolic::one()),
5✔
308
            ScheduleType_Sequential::create(),
5✔
309
            {},
5✔
310
            block->debug_info()
5✔
311
        );
5✔
312
        current_seq = &loop_c.root();
5✔
313

314
        // Kernel dimension loops
315
        symbolic::SymbolVec ks;
5✔
316
        ks.reserve(dims);
5✔
317
        symbolic::MultiExpression input_indices; // i_expr for each spatial dim
5✔
318
        input_indices.reserve(dims);
5✔
319
        for (size_t i = 0; i < dims; i++) {
15✔
320
            auto k_container = builder.find_new_name("_k");
10✔
321
            builder.add_container(k_container, indvar_type);
10✔
322
            auto k = symbolic::symbol(k_container);
10✔
323
            ks.push_back(k);
10✔
324
            auto& loop_k = builder.add_map(
10✔
325
                *current_seq,
10✔
326
                k,
10✔
327
                symbolic::Lt(k, this->kernel_shape_[i]),
10✔
328
                symbolic::zero(),
10✔
329
                symbolic::add(k, symbolic::one()),
10✔
330
                ScheduleType_Sequential::create(),
10✔
331
                {},
10✔
332
                block->debug_info()
10✔
333
            );
10✔
334
            current_seq = &loop_k.root();
10✔
335
        }
10✔
336

337
        // Output spatial dimension loops (with padding/dilation bounds check)
338
        symbolic::SymbolVec os;
5✔
339
        os.reserve(dims);
5✔
340
        for (size_t i = 0; i < dims; i++) {
15✔
341
            auto o_container = builder.find_new_name("_o");
10✔
342
            builder.add_container(o_container, indvar_type);
10✔
343
            auto o = symbolic::symbol(o_container);
10✔
344
            os.push_back(o);
10✔
345
            // i_expr = o * stride - pad + k * dilation
346
            auto i_expr = symbolic::
10✔
347
                add(symbolic::sub(symbolic::mul(o, this->strides_[i]), this->pads_[i]),
10✔
348
                    symbolic::mul(ks[i], this->dilations_[i]));
10✔
349
            input_indices.push_back(i_expr);
10✔
350
            auto& loop_o = builder.add_map(
10✔
351
                *current_seq,
10✔
352
                o,
10✔
353
                symbolic::And(symbolic::Lt(o, out_shape[i]), symbolic::Lt(i_expr, this->shape_[i + 2])),
10✔
354
                symbolic::zero(),
10✔
355
                symbolic::add(o, symbolic::one()),
10✔
356
                ScheduleType_Sequential::create(),
10✔
357
                {},
10✔
358
                block->debug_info()
10✔
359
            );
10✔
360
            current_seq = &loop_o.root();
10✔
361
        }
10✔
362

363
        // Patches subset: [c, k0, k1, ..., o0, o1, ...]
364
        // Patches shape: [C_in, kH, kW, ..., H_out, W_out, ...]
365
        data_flow::Subset patches_subset;
5✔
366
        patches_subset.push_back(c);
5✔
367
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
5✔
368
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
5✔
369
        symbolic::MultiExpression patches_shape;
5✔
370
        patches_shape.push_back(this->shape_[1]); // C_in
5✔
371
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
5✔
372
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
5✔
373
        types::Tensor patches_tensor_type(base_type, patches_shape);
5✔
374

375
        // X subset: [n, c, i0, i1, ...]
376
        data_flow::Subset subset_X;
5✔
377
        subset_X.push_back(n);
5✔
378
        subset_X.push_back(c);
5✔
379
        subset_X.insert(subset_X.end(), input_indices.begin(), input_indices.end());
5✔
380

381
        // Copy X[n, c, i...] → patches[c, k..., o...]
382
        auto& im2col_block = builder.add_block(*current_seq, {}, block->debug_info());
5✔
383
        {
5✔
384
            auto& X_access = builder.add_access(im2col_block, access_X->data(), access_X->debug_info());
5✔
385
            auto& patches_access = builder.add_access(im2col_block, patches_container, this->debug_info());
5✔
386
            auto& tasklet =
5✔
387
                builder.add_tasklet(im2col_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
388
            builder.add_computational_memlet(
5✔
389
                im2col_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
5✔
390
            );
5✔
391
            builder.add_computational_memlet(
5✔
392
                im2col_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
5✔
393
            );
5✔
394
        }
5✔
395

396
        // Reference to Y[n, :, :, ...] — offset = n * F * spatial
397
        auto ref_Y_container = builder.find_new_name("_ref_Y");
5✔
398
        types::Scalar ref_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
5✔
399
        types::Pointer ref_Y_type(ref_Y_base_type);
5✔
400
        builder.add_container(ref_Y_container, ref_Y_type);
5✔
401
        auto ref_Y_offset = symbolic::mul(n, symbolic::mul(this->output_channels_, spatial));
5✔
402
        auto& ref_Y_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
403
        {
5✔
404
            auto& Y_access = builder.add_access(ref_Y_block, access_Y->data(), access_Y->debug_info());
5✔
405
            auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, access_Y->debug_info());
5✔
406
            builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_offset}, ref_Y_type);
5✔
407
        }
5✔
408

409
        // GEMM: Y[n][F, spatial] = W[F, K] × patches[K, spatial]
410
        // NoTrans × NoTrans, lda=K, ldb=spatial, ldc=spatial
411
        auto& gemm_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
412
        {
5✔
413
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
5✔
414
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
5✔
415
            auto& W_access = builder.add_access(gemm_block, access_W->data(), access_W->debug_info());
5✔
416
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
5✔
417
            auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
5✔
418
            auto& ref_Y_access_out = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
5✔
419
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
5✔
420
                gemm_block,
5✔
421
                this->debug_info(),
5✔
422
                blas::ImplementationType_BLAS,
5✔
423
                precision,
5✔
424
                blas::BLAS_Layout::RowMajor,
5✔
425
                blas::BLAS_Transpose::No, // transA
5✔
426
                blas::BLAS_Transpose::No, // transB
5✔
427
                this->output_channels_, // m = F
5✔
428
                spatial, // n = spatial
5✔
429
                gemm_k, // k = K
5✔
430
                gemm_k, // lda = K
5✔
431
                spatial, // ldb = spatial
5✔
432
                spatial // ldc = spatial
5✔
433
            );
5✔
434
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
5✔
435
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
5✔
436
            builder.add_computational_memlet(
5✔
437
                gemm_block,
5✔
438
                W_access,
5✔
439
                libnode,
5✔
440
                "__A",
5✔
441
                {},
5✔
442
                types::Pointer(types::Scalar(iedge_W->base_type().primitive_type())),
5✔
443
                iedge_W->debug_info()
5✔
444
            );
5✔
445
            builder.add_computational_memlet(
5✔
446
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
5✔
447
            );
5✔
448
            builder.add_computational_memlet(
5✔
449
                gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, oedge_Y->debug_info()
5✔
450
            );
5✔
451
            builder.add_computational_memlet(
5✔
452
                gemm_block, libnode, "__C", ref_Y_access_out, {}, ref_Y_type, oedge_Y->debug_info()
5✔
453
            );
5✔
454
        }
5✔
455

456
        // Add bias if available: Y[n, f, o...] += B[f]
457
        if (has_bias) {
5✔
NEW
458
            auto l_container = builder.find_new_name("_l");
×
NEW
459
            builder.add_container(l_container, indvar_type);
×
NEW
460
            auto l = symbolic::symbol(l_container);
×
NEW
461
            auto& loop_l = builder.add_map(
×
NEW
462
                loop_n.root(),
×
NEW
463
                l,
×
NEW
464
                symbolic::Lt(l, this->output_channels_),
×
UNCOV
465
                symbolic::zero(),
×
NEW
466
                symbolic::add(l, symbolic::one()),
×
UNCOV
467
                ScheduleType_Sequential::create(),
×
UNCOV
468
                {},
×
UNCOV
469
                block->debug_info()
×
UNCOV
470
            );
×
NEW
471
            structured_control_flow::Sequence* bias_seq = &loop_l.root();
×
472

NEW
473
            symbolic::SymbolVec bias_os;
×
NEW
474
            bias_os.reserve(dims);
×
NEW
475
            for (size_t i = 0; i < dims; i++) {
×
NEW
476
                auto o_container = builder.find_new_name("_o");
×
NEW
477
                builder.add_container(o_container, indvar_type);
×
NEW
478
                auto o = symbolic::symbol(o_container);
×
NEW
479
                bias_os.push_back(o);
×
NEW
480
                auto& loop_o = builder.add_map(
×
NEW
481
                    *bias_seq,
×
NEW
482
                    o,
×
NEW
483
                    symbolic::Lt(o, out_shape[i]),
×
NEW
484
                    symbolic::zero(),
×
NEW
485
                    symbolic::add(o, symbolic::one()),
×
NEW
486
                    ScheduleType_Sequential::create(),
×
NEW
487
                    {},
×
NEW
488
                    block->debug_info()
×
NEW
489
                );
×
NEW
490
                bias_seq = &loop_o.root();
×
NEW
491
            }
×
492

NEW
493
            data_flow::Subset Y_subset;
×
NEW
494
            Y_subset.push_back(n);
×
NEW
495
            Y_subset.push_back(l);
×
NEW
496
            Y_subset.insert(Y_subset.end(), bias_os.begin(), bias_os.end());
×
497

NEW
498
            auto& bias_block = builder.add_block(*bias_seq, {}, block->debug_info());
×
NEW
499
            {
×
NEW
500
                auto& B_access = builder.add_access(bias_block, access_B->data(), access_B->debug_info());
×
NEW
501
                auto& Y_access_in = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
NEW
502
                auto& Y_access_out = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
NEW
503
                auto& tasklet = builder.add_tasklet(
×
NEW
504
                    bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
NEW
505
                );
×
NEW
506
                builder.add_computational_memlet(
×
NEW
507
                    bias_block, Y_access_in, tasklet, "_in1", Y_subset, oedge_Y->base_type(), this->debug_info()
×
NEW
508
                );
×
NEW
509
                builder.add_computational_memlet(
×
NEW
510
                    bias_block, B_access, tasklet, "_in2", {l}, iedge_B->base_type(), iedge_B->debug_info()
×
NEW
511
                );
×
NEW
512
                builder.add_computational_memlet(
×
NEW
513
                    bias_block, tasklet, "_out", Y_access_out, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
NEW
514
                );
×
NEW
515
            }
×
UNCOV
516
        }
×
517

518
        // Free patches
519
        auto& patches_free_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
520
        {
5✔
521
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
522
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
523
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
5✔
524
            builder.add_computational_memlet(
5✔
525
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
5✔
526
            );
5✔
527
            builder.add_computational_memlet(
5✔
528
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
5✔
529
            );
5✔
530
        }
5✔
531

532
        /* ===== No groups ====================================================================== */
533

534
    } else {
5✔
535
        /* ===== Groups ========================================================================= */
536

537
        auto in_channels = symbolic::div(this->shape_[1], this->group_);
×
538
        auto out_channels = symbolic::div(this->output_channels_, this->group_);
×
539

540
        // Compute K_group = in_channels * prod(kernel_shape), spatial = prod(out_shape)
NEW
541
        symbolic::Expression gemm_k = in_channels;
×
NEW
542
        for (size_t i = 0; i < dims; i++) {
×
NEW
543
            gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
×
NEW
544
        }
×
NEW
545
        symbolic::Expression spatial = symbolic::one();
×
NEW
546
        for (size_t i = 0; i < dims; i++) {
×
NEW
547
            spatial = symbolic::mul(spatial, out_shape[i]);
×
NEW
548
        }
×
549

550
        // Patches buffer: K_group × spatial (per group iteration, reused)
NEW
551
        symbolic::Expression patches_size = symbolic::mul(gemm_k, spatial);
×
NEW
552
        types::Pointer patches_type(base_type);
×
NEW
553
        auto patches_container = builder.find_new_name("_patches");
×
NEW
554
        builder.add_container(patches_container, patches_type);
×
555

556
        // Malloc patches (outside loops)
NEW
557
        auto& patches_malloc_block = builder.add_block(new_sequence, {}, block->debug_info());
×
NEW
558
        {
×
NEW
559
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
×
NEW
560
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
×
NEW
561
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
×
NEW
562
            );
×
NEW
563
            builder.add_computational_memlet(
×
NEW
564
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
×
NEW
565
            );
×
NEW
566
        }
×
567

568
        // Batch loop
569
        auto n_container = builder.find_new_name("_n");
×
570
        builder.add_container(n_container, indvar_type);
×
571
        auto n = symbolic::symbol(n_container);
×
572
        auto& loop_n = builder.add_map(
×
573
            new_sequence,
×
574
            n,
×
575
            symbolic::Lt(n, this->shape_[0]),
×
576
            symbolic::zero(),
×
577
            symbolic::add(n, symbolic::one()),
×
578
            ScheduleType_Sequential::create(),
×
579
            {},
×
580
            block->debug_info()
×
581
        );
×
582

583
        // Group loop
584
        auto g_container = builder.find_new_name("_g");
×
585
        builder.add_container(g_container, indvar_type);
×
586
        auto g = symbolic::symbol(g_container);
×
587
        auto& loop_g = builder.add_map(
×
588
            loop_n.root(),
×
589
            g,
×
590
            symbolic::Lt(g, this->group_),
×
591
            symbolic::zero(),
×
592
            symbolic::add(g, symbolic::one()),
×
593
            ScheduleType_Sequential::create(),
×
594
            {},
×
595
            block->debug_info()
×
596
        );
×
597

598
        // Memset patches to zero (per batch×group iteration)
599
        auto& patches_memset_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
600
        {
×
601
            auto& patches_access = builder.add_access(patches_memset_block, patches_container, this->debug_info());
×
602
            auto& libnode = builder.add_library_node<stdlib::MemsetNode>(
×
603
                patches_memset_block,
×
604
                this->debug_info(),
×
605
                symbolic::zero(),
×
606
                symbolic::mul(patches_size, symbolic::size_of_type(base_type))
×
607
            );
×
608
            builder.add_computational_memlet(
×
609
                patches_memset_block, libnode, "_ptr", patches_access, {}, patches_type, this->debug_info()
×
610
            );
×
611
        }
×
612

613
        // Im2col loops: c, k[0..dims-1], o[0..dims-1]
614
        structured_control_flow::Sequence* current_seq = &loop_g.root();
×
615

616
        // Channel loop (over in_channels per group)
617
        auto c_container = builder.find_new_name("_c");
×
618
        builder.add_container(c_container, indvar_type);
×
619
        auto c = symbolic::symbol(c_container);
×
620
        auto& loop_c = builder.add_map(
×
621
            *current_seq,
×
622
            c,
×
623
            symbolic::Lt(c, in_channels),
×
624
            symbolic::zero(),
×
625
            symbolic::add(c, symbolic::one()),
×
626
            ScheduleType_Sequential::create(),
×
627
            {},
×
628
            block->debug_info()
×
629
        );
×
630
        current_seq = &loop_c.root();
×
631

632
        // Kernel dimension loops
633
        symbolic::SymbolVec ks;
×
634
        ks.reserve(dims);
×
635
        for (size_t i = 0; i < dims; i++) {
×
636
            auto k_container = builder.find_new_name("_k");
×
637
            builder.add_container(k_container, indvar_type);
×
638
            auto k = symbolic::symbol(k_container);
×
639
            ks.push_back(k);
×
640
            auto& loop_k = builder.add_map(
×
641
                *current_seq,
×
642
                k,
×
NEW
643
                symbolic::Lt(k, this->kernel_shape_[i]),
×
644
                symbolic::zero(),
×
645
                symbolic::add(k, symbolic::one()),
×
646
                ScheduleType_Sequential::create(),
×
647
                {},
×
648
                block->debug_info()
×
649
            );
×
650
            current_seq = &loop_k.root();
×
651
        }
×
652

653
        // Output spatial loops (with bounds check for padding/dilation)
NEW
654
        symbolic::SymbolVec os;
×
NEW
655
        os.reserve(dims);
×
NEW
656
        symbolic::MultiExpression input_indices;
×
NEW
657
        input_indices.reserve(dims);
×
NEW
658
        for (size_t i = 0; i < dims; i++) {
×
NEW
659
            auto o_container = builder.find_new_name("_o");
×
NEW
660
            builder.add_container(o_container, indvar_type);
×
NEW
661
            auto o = symbolic::symbol(o_container);
×
NEW
662
            os.push_back(o);
×
NEW
663
            auto i_expr = symbolic::
×
NEW
664
                add(symbolic::sub(symbolic::mul(o, this->strides_[i]), this->pads_[i]),
×
NEW
665
                    symbolic::mul(ks[i], this->dilations_[i]));
×
NEW
666
            input_indices.push_back(i_expr);
×
NEW
667
            auto& loop_o = builder.add_map(
×
NEW
668
                *current_seq,
×
NEW
669
                o,
×
NEW
670
                symbolic::And(symbolic::Lt(o, out_shape[i]), symbolic::Lt(i_expr, this->shape_[i + 2])),
×
NEW
671
                symbolic::zero(),
×
NEW
672
                symbolic::add(o, symbolic::one()),
×
NEW
673
                ScheduleType_Sequential::create(),
×
NEW
674
                {},
×
NEW
675
                block->debug_info()
×
NEW
676
            );
×
NEW
677
            current_seq = &loop_o.root();
×
NEW
678
        }
×
679

680
        // Patches subset: [c, k0, k1, ..., o0, o1, ...]
681
        // Patches shape: [in_channels, kernel..., out_shape...]
682
        data_flow::Subset patches_subset;
×
683
        patches_subset.push_back(c);
×
684
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
×
685
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
×
686
        symbolic::MultiExpression patches_shape;
×
687
        patches_shape.push_back(in_channels);
×
688
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
×
689
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
×
690
        types::Tensor patches_tensor_type(base_type, patches_shape);
×
691

692
        // X subset: [n, g * in_channels + c, i0, i1, ...]
693
        data_flow::Subset subset_X;
×
694
        subset_X.push_back(n);
×
695
        subset_X.push_back(symbolic::add(symbolic::mul(in_channels, g), c));
×
NEW
696
        subset_X.insert(subset_X.end(), input_indices.begin(), input_indices.end());
×
697

698
        // Copy X → patches
NEW
699
        auto& im2col_block = builder.add_block(*current_seq, {}, block->debug_info());
×
700
        {
×
NEW
701
            auto& X_access = builder.add_access(im2col_block, access_X->data(), access_X->debug_info());
×
NEW
702
            auto& patches_access = builder.add_access(im2col_block, patches_container, this->debug_info());
×
703
            auto& tasklet =
×
NEW
704
                builder.add_tasklet(im2col_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
×
705
            builder.add_computational_memlet(
×
NEW
706
                im2col_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
×
707
            );
×
708
            builder.add_computational_memlet(
×
NEW
709
                im2col_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
×
710
            );
×
711
        }
×
712

713
        // Reference to W[g, :, :, ...] — offset = g * out_channels * K_group
714
        auto ref_W_container = builder.find_new_name("_ref_W");
×
715
        types::Scalar ref_W_base_type(builder.subject().type(access_W->data()).primitive_type());
×
716
        types::Pointer ref_W_type(ref_W_base_type);
×
717
        builder.add_container(ref_W_container, ref_W_type);
×
NEW
718
        auto ref_W_offset = symbolic::mul(g, symbolic::mul(out_channels, gemm_k));
×
719
        auto& ref_W_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
720
        {
×
721
            auto& W_access = builder.add_access(ref_W_block, access_W->data(), access_W->debug_info());
×
722
            auto& ref_W_access = builder.add_access(ref_W_block, ref_W_container, access_W->debug_info());
×
NEW
723
            builder.add_reference_memlet(ref_W_block, W_access, ref_W_access, {ref_W_offset}, ref_W_type);
×
724
        }
×
725

726
        // Reference to Y[n, g*out_channels, 0, ...] — offset = (n * F + g * out_channels) * spatial
727
        auto ref_Y_container = builder.find_new_name("_ref_Y");
×
728
        types::Scalar ref_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
×
729
        types::Pointer ref_Y_type(ref_Y_base_type);
×
730
        builder.add_container(ref_Y_container, ref_Y_type);
×
NEW
731
        auto ref_Y_offset =
×
NEW
732
            symbolic::mul(symbolic::add(symbolic::mul(this->output_channels_, n), symbolic::mul(out_channels, g)), spatial);
×
733
        auto& ref_Y_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
734
        {
×
735
            auto& Y_access = builder.add_access(ref_Y_block, access_Y->data(), access_Y->debug_info());
×
736
            auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, access_Y->debug_info());
×
NEW
737
            builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_offset}, ref_Y_type);
×
738
        }
×
739

740
        // GEMM: Y_ref[out_channels, spatial] = W_ref[out_channels, K_group] × patches[K_group, spatial]
741
        auto& gemm_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
742
        {
×
743
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
×
744
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
×
745
            auto& ref_W_access = builder.add_access(gemm_block, ref_W_container, access_W->debug_info());
×
746
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
×
747
            auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
748
            auto& ref_Y_access_out = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
749
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
×
750
                gemm_block,
×
751
                this->debug_info(),
×
752
                blas::ImplementationType_BLAS,
×
NEW
753
                precision,
×
NEW
754
                blas::BLAS_Layout::RowMajor,
×
755
                blas::BLAS_Transpose::No, // transA
×
756
                blas::BLAS_Transpose::No, // transB
×
NEW
757
                out_channels, // m
×
NEW
758
                spatial, // n
×
759
                gemm_k, // k
×
760
                gemm_k, // lda
×
NEW
761
                spatial, // ldb
×
NEW
762
                spatial // ldc
×
763
            );
×
764
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
×
765
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
×
766
            builder
×
767
                .add_computational_memlet(gemm_block, ref_W_access, libnode, "__A", {}, ref_W_type, iedge_W->debug_info());
×
768
            builder.add_computational_memlet(
×
769
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
×
770
            );
×
771
            builder.add_computational_memlet(
×
772
                gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, oedge_Y->debug_info()
×
773
            );
×
774
            builder.add_computational_memlet(
×
775
                gemm_block, libnode, "__C", ref_Y_access_out, {}, ref_Y_type, oedge_Y->debug_info()
×
776
            );
×
777
        }
×
778

779
        // Add bias if available: Y[n, g*out_channels + l, o...] += B[g*out_channels + l]
780
        if (has_bias) {
×
UNCOV
781
            auto l_container = builder.find_new_name("_l");
×
782
            builder.add_container(l_container, indvar_type);
×
783
            auto l = symbolic::symbol(l_container);
×
784
            auto& loop_l = builder.add_map(
×
785
                loop_g.root(),
×
786
                l,
×
787
                symbolic::Lt(l, out_channels),
×
788
                symbolic::zero(),
×
789
                symbolic::add(l, symbolic::one()),
×
790
                ScheduleType_Sequential::create(),
×
791
                {},
×
792
                block->debug_info()
×
793
            );
×
NEW
794
            structured_control_flow::Sequence* bias_seq = &loop_l.root();
×
795

NEW
796
            symbolic::SymbolVec bias_os;
×
NEW
797
            bias_os.reserve(dims);
×
798
            for (size_t i = 0; i < dims; i++) {
×
799
                auto o_container = builder.find_new_name("_o");
×
800
                builder.add_container(o_container, indvar_type);
×
801
                auto o = symbolic::symbol(o_container);
×
NEW
802
                bias_os.push_back(o);
×
803
                auto& loop_o = builder.add_map(
×
NEW
804
                    *bias_seq,
×
805
                    o,
×
806
                    symbolic::Lt(o, out_shape[i]),
×
807
                    symbolic::zero(),
×
808
                    symbolic::add(o, symbolic::one()),
×
809
                    ScheduleType_Sequential::create(),
×
810
                    {},
×
811
                    block->debug_info()
×
812
                );
×
NEW
813
                bias_seq = &loop_o.root();
×
814
            }
×
815

UNCOV
816
            data_flow::Subset Y_subset;
×
817
            Y_subset.push_back(n);
×
818
            Y_subset.push_back(symbolic::add(symbolic::mul(out_channels, g), l));
×
NEW
819
            Y_subset.insert(Y_subset.end(), bias_os.begin(), bias_os.end());
×
820
            auto B_subset = symbolic::add(symbolic::mul(out_channels, g), l);
×
821

NEW
822
            auto& bias_block = builder.add_block(*bias_seq, {}, block->debug_info());
×
823
            {
×
824
                auto& B_access = builder.add_access(bias_block, access_B->data(), access_B->debug_info());
×
825
                auto& Y_access_in = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
826
                auto& Y_access_out = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
827
                auto& tasklet = builder.add_tasklet(
×
828
                    bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
829
                );
×
830
                builder.add_computational_memlet(
×
831
                    bias_block, Y_access_in, tasklet, "_in1", Y_subset, oedge_Y->base_type(), this->debug_info()
×
832
                );
×
833
                builder.add_computational_memlet(
×
834
                    bias_block, B_access, tasklet, "_in2", {B_subset}, iedge_B->base_type(), iedge_B->debug_info()
×
835
                );
×
836
                builder.add_computational_memlet(
×
837
                    bias_block, tasklet, "_out", Y_access_out, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
838
                );
×
839
            }
×
840
        }
×
841

842
        // Free patches (outside loops)
NEW
843
        auto& patches_free_block = builder.add_block(new_sequence, {}, block->debug_info());
×
844
        {
×
845
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
846
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
847
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
×
848
            builder.add_computational_memlet(
×
849
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
×
850
            );
×
851
            builder.add_computational_memlet(
×
852
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
×
853
            );
×
854
        }
×
855

856
        /* ===== Groups ========================================================================= */
857
    }
×
858

859
    // Clean up the original block
860
    builder.remove_memlet(*block, *iedge_X);
5✔
861
    builder.remove_memlet(*block, *iedge_W);
5✔
862
    if (has_bias) {
5✔
863
        builder.remove_memlet(*block, *iedge_B);
×
864
    }
×
865
    builder.remove_memlet(*block, *oedge_Y);
5✔
866
    builder.remove_node(*block, *access_X);
5✔
867
    builder.remove_node(*block, *access_W);
5✔
868
    if (has_bias) {
5✔
869
        builder.remove_node(*block, *access_B);
×
870
    }
×
871
    builder.remove_node(*block, *access_Y);
5✔
872
    builder.remove_node(*block, *this);
5✔
873
    builder.remove_child(*block_parent, block_index + 1);
5✔
874

875
    return true;
5✔
876
}
5✔
877

878
// Returns the set of symbols that appear in any of the node's symbolic
// attributes: shape, kernel shape, strides, pads, dilations, output channels
// and group.
symbolic::SymbolSet ConvNode::symbols() const {
    symbolic::SymbolSet collected;

    // Inserts every atom of a single expression into the result set.
    const auto gather = [&collected](const auto& expression) {
        for (auto& atom : symbolic::atoms(expression)) {
            collected.insert(atom);
        }
    };

    // Runs `gather` over each entry of a list of expressions.
    const auto gather_each = [&gather](const auto& expressions) {
        for (auto& expression : expressions) {
            gather(expression);
        }
    };

    gather_each(shape_);
    gather_each(kernel_shape_);
    gather_each(strides_);
    gather_each(pads_);
    gather_each(dilations_);
    gather(output_channels_);
    gather(group_);

    return collected;
}
×
915

916
// Substitutes `old_expression` with `new_expression` in every symbolic
// attribute of the node: each entry of shape, kernel shape, strides, pads and
// dilations, followed by the output-channel and group expressions.
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Rewrites one stored expression in place.
    const auto substitute = [&old_expression, &new_expression](auto& target) {
        target = symbolic::subs(target, old_expression, new_expression);
    };

    // Rewrites every entry of a list of expressions in place.
    const auto substitute_each = [&substitute](auto& targets) {
        for (auto& target : targets) {
            substitute(target);
        }
    };

    substitute_each(shape_);
    substitute_each(kernel_shape_);
    substitute_each(strides_);
    substitute_each(pads_);
    substitute_each(dilations_);
    substitute(output_channels_);
    substitute(group_);
}
×
935

936
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
937
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
1✔
938
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
1✔
939
        element_id,
1✔
940
        this->debug_info(),
1✔
941
        vertex,
1✔
942
        parent,
1✔
943
        shape_,
1✔
944
        kernel_shape_,
1✔
945
        strides_,
1✔
946
        pads_,
1✔
947
        dilations_,
1✔
948
        output_channels_,
1✔
949
        group_
1✔
950
    ));
1✔
951
}
1✔
952

953
std::string ConvNode::toStr() const {
×
954
    std::stringstream result;
×
955
    result << "Conv(shape=[";
×
956
    for (size_t i = 0; i < shape_.size(); ++i) {
×
957
        if (i > 0) {
×
958
            result << ", ";
×
959
        }
×
960
        result << shape_[i]->__str__();
×
961
    }
×
962
    result << "], kernel_shape=[";
×
963
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
964
        if (i > 0) {
×
965
            result << ", ";
×
966
        }
×
967
        result << kernel_shape_[i]->__str__();
×
968
    }
×
969
    result << "], strides=[";
×
970
    for (size_t i = 0; i < strides_.size(); ++i) {
×
971
        if (i > 0) {
×
972
            result << ", ";
×
973
        }
×
974
        result << strides_[i]->__str__();
×
975
    }
×
976
    result << "], pads=[";
×
977
    for (size_t i = 0; i < pads_.size(); ++i) {
×
978
        if (i > 0) {
×
979
            result << ", ";
×
980
        }
×
981
        result << pads_[i]->__str__();
×
982
    }
×
983
    result << "], dilations=[";
×
984
    for (size_t i = 0; i < dilations_.size(); ++i) {
×
985
        if (i > 0) {
×
986
            result << ", ";
×
987
        }
×
988
        result << dilations_[i]->__str__();
×
989
    }
×
990
    result << "], output_channels=" + output_channels_->__str__();
×
991
    result << ", group=" + group_->__str__() + ")";
×
992
    return result.str();
×
993
}
×
994

995
// Serializes a ConvNode to JSON. The result holds the library node code under
// "code", one array of serialized expressions for each of "shape",
// "kernel_shape", "strides", "pads" and "dilations", and single serialized
// expressions under "output_channels" and "group".
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    // The node is cast without a runtime check, so the caller must pass a ConvNode.
    const auto& conv = static_cast<const ConvNode&>(library_node);

    nlohmann::json j;
    j["code"] = conv.code().value();

    serializer::JSONSerializer expr_serializer;

    // Stores `expressions` under `key` as an array of serialized expressions.
    const auto store_list = [&j, &expr_serializer](const char* key, const auto& expressions) {
        auto& target = j[key];
        target = nlohmann::json::array();
        for (auto& expression : expressions) {
            target.push_back(expr_serializer.expression(expression));
        }
    };

    store_list("shape", conv.shape());
    store_list("kernel_shape", conv.kernel_shape());
    store_list("strides", conv.strides());
    store_list("pads", conv.pads());
    store_list("dilations", conv.dilations());

    j["output_channels"] = expr_serializer.expression(conv.output_channels());
    j["group"] = expr_serializer.expression(conv.group());

    return j;
}
×
1033

1034
// Rebuilds a ConvNode from its JSON form and adds it to `parent` via `builder`.
//
// "element_id", "code", "debug_info" and "kernel_shape" are asserted to be
// present. "shape", "strides", "pads" and "dilations" are optional and default
// to empty lists; "output_channels" and "group" are optional and default to one.
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parses each string entry of a JSON array into a symbolic expression.
    const auto parse_list = [](const nlohmann::json& entries) {
        std::vector<symbolic::Expression> parsed;
        for (const auto& entry : entries) {
            parsed.push_back(symbolic::parse(entry.get<std::string>()));
        }
        return parsed;
    };

    // Same as parse_list, but yields an empty list when `key` is absent.
    const auto parse_optional_list = [&j, &parse_list](const char* key) {
        if (!j.contains(key)) {
            return std::vector<symbolic::Expression>{};
        }
        return parse_list(j[key]);
    };

    std::vector<symbolic::Expression> shape = parse_optional_list("shape");
    // kernel_shape is mandatory (asserted above), so it is read unconditionally.
    std::vector<symbolic::Expression> kernel_shape = parse_list(j["kernel_shape"]);
    std::vector<symbolic::Expression> strides = parse_optional_list("strides");
    std::vector<symbolic::Expression> pads = parse_optional_list("pads");
    std::vector<symbolic::Expression> dilations = parse_optional_list("dilations");

    symbolic::Expression output_channels = symbolic::one();
    if (j.contains("output_channels")) {
        output_channels = symbolic::parse(j["output_channels"].get<std::string>());
    }

    symbolic::Expression group = symbolic::one();
    if (j.contains("group")) {
        group = symbolic::parse(j["group"].get<std::string>());
    }

    sdfg::serializer::JSONSerializer json_serializer;
    DebugInfo debug_info = json_serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
×
1091

1092
} // namespace tensor
1093
} // namespace math
1094
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc