• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 24178515126

09 Apr 2026 07:39AM UTC coverage: 64.863% (-0.1%) from 64.986%
24178515126

Pull #663

github

web-flow
Merge d88e57c14 into 223814883
Pull Request #663: expands conv node patch-wise

89 of 203 new or added lines in 1 file covered. (43.84%)

8 existing lines in 1 file now uncovered.

29147 of 44936 relevant lines covered (64.86%)

602.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

37.66
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4
#include <sstream>
5

6
#include "sdfg/analysis/analysis.h"
7
#include "sdfg/builder/structured_sdfg_builder.h"
8
#include "sdfg/data_flow/access_node.h"
9
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
10
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
11
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
12
#include "sdfg/data_flow/library_nodes/stdlib/memset.h"
13
#include "sdfg/data_flow/memlet.h"
14
#include "sdfg/data_flow/tasklet.h"
15
#include "sdfg/exceptions.h"
16
#include "sdfg/structured_control_flow/block.h"
17
#include "sdfg/structured_control_flow/map.h"
18
#include "sdfg/structured_control_flow/sequence.h"
19
#include "sdfg/symbolic/symbolic.h"
20
#include "sdfg/types/pointer.h"
21
#include "sdfg/types/scalar.h"
22
#include "sdfg/types/tensor.h"
23
#include "sdfg/types/type.h"
24

25
#include "sdfg/analysis/scope_analysis.h"
26
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
27
#include "symengine/integer.h"
28
#include "symengine/symengine_rcp.h"
29

30
namespace sdfg {
31
namespace math {
32
namespace tensor {
33

34
/// Constructs a convolution library node.
///
/// The node exposes one output connector ("Y") and three input connectors
/// ("X", "W", "B"); X (input tensor) and W (weights) are required, while
/// B (bias) is optional.
///
/// @param element_id      Unique id of this element within the SDFG.
/// @param debug_info      Source-location information for diagnostics.
/// @param vertex          Graph vertex this node occupies in @p parent.
/// @param parent          Data-flow graph that owns this node.
/// @param shape           Input tensor shape; validate() requires
///                        kernel_shape.size() + 2 entries, with shape[1]
///                        treated as the input-channel count.
/// @param kernel_shape    Spatial extents of the convolution kernel.
/// @param strides         Stride per spatial dimension (same size as
///                        kernel_shape).
/// @param pads            Padding, two entries per spatial dimension
///                        (start and end for each axis).
/// @param dilations       Dilation per spatial dimension (same size as
///                        kernel_shape).
/// @param output_channels Number of output channels.
/// @param group           Number of convolution groups; must evenly divide
///                        both input and output channels when all are
///                        concrete integers (checked in validate()).
ConvNode::ConvNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<symbolic::Expression>& kernel_shape,
    const std::vector<symbolic::Expression>& strides,
    const std::vector<symbolic::Expression>& pads,
    const std::vector<symbolic::Expression>& dilations,
    symbolic::Expression output_channels,
    symbolic::Expression group
)
    : TensorNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_Conv,
          {"Y"},
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
          data_flow::ImplementationType_NONE
      ),
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
      output_channels_(output_channels), group_(group) {}
59

60
void ConvNode::validate(const Function& function) const {
17✔
61
    TensorNode::validate(function);
17✔
62

63
    auto& graph = this->get_parent();
17✔
64

65
    // Custom validation for ConvNode that handles optional bias input
66
    // We expect X, W as required inputs and optionally B (bias)
67

68
    // Collect all input edges by connector name
69
    std::map<std::string, const data_flow::Memlet*> input_edges;
17✔
70
    for (auto& iedge : graph.in_edges(*this)) {
34✔
71
        input_edges[iedge.dst_conn()] = &iedge;
34✔
72
    }
34✔
73

74
    // Check that required inputs X and W are present
75
    if (input_edges.find("X") == input_edges.end()) {
17✔
76
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
77
    }
×
78
    if (input_edges.find("W") == input_edges.end()) {
17✔
79
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
80
    }
×
81

82
    // Validate that parameters are not empty
83
    if (shape_.empty()) {
17✔
84
        throw InvalidSDFGException("ConvNode shape cannot be empty");
×
85
    }
×
86
    if (kernel_shape_.empty()) {
17✔
87
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
88
    }
×
89
    if (strides_.empty()) {
17✔
90
        throw InvalidSDFGException("ConvNode strides cannot be empty");
×
91
    }
×
92
    if (pads_.empty()) {
17✔
93
        throw InvalidSDFGException("ConvNode pads cannot be empty");
×
94
    }
×
95
    if (dilations_.empty()) {
17✔
96
        throw InvalidSDFGException("ConvNode dilations cannot be empty");
×
97
    }
×
98

99
    // Validate consistent dimensions
100
    size_t spatial_dims = kernel_shape_.size();
17✔
101

102
    if (shape_.size() != spatial_dims + 2) {
17✔
103
        throw InvalidSDFGException("ConvNode shape must match kernel spatial dimensions + 2");
×
104
    }
×
105

106
    if (strides_.size() != spatial_dims) {
17✔
107
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
108
    }
1✔
109

110
    if (pads_.size() != 2 * spatial_dims) {
16✔
111
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
112
    }
1✔
113

114
    if (dilations_.size() != spatial_dims) {
15✔
115
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
116
    }
×
117

118
    // Validate groups
119
    if (SymEngine::is_a<SymEngine::Integer>(*this->group_)) {
15✔
120
        auto group_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->group_)->as_int();
15✔
121
        if (SymEngine::is_a<SymEngine::Integer>(*this->shape_[1])) {
15✔
122
            auto input_channels_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->shape_[1])->as_int();
15✔
123
            if (input_channels_int % group_int != 0) {
15✔
124
                throw InvalidSDFGException("ConvNode input channels must be divisible by groups");
×
125
            }
×
126
        }
15✔
127
        if (SymEngine::is_a<SymEngine::Integer>(*this->output_channels_)) {
15✔
128
            auto output_channels_int =
15✔
129
                SymEngine::rcp_static_cast<const SymEngine::Integer>(this->output_channels_)->as_int();
15✔
130
            if (output_channels_int % group_int != 0) {
15✔
131
                throw InvalidSDFGException("ConvNode output channels must be divisible by groups");
×
132
            }
×
133
        }
15✔
134
    }
15✔
135
}
15✔
136

137
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
138
    // Validate nodes are standalone in the data flow graph
139
    auto& dfg = this->get_parent();
5✔
140
    if ((dfg.nodes().size() != 4 || dfg.edges().size() != 3) && (dfg.nodes().size() != 5 || dfg.edges().size() != 4)) {
5✔
141
        return false;
×
142
    }
×
143

144
    // Get edges
145
    auto iedges = dfg.in_edges_by_connector(*this);
5✔
146
    auto oedges = dfg.out_edges_by_connector(*this);
5✔
147
    if (iedges.size() != 3 || oedges.size() != 1) {
5✔
148
        return false;
×
149
    }
×
150
    auto* iedge_X = iedges.at(0);
5✔
151
    auto* iedge_W = iedges.at(1);
5✔
152
    auto* iedge_B = iedges.at(2);
5✔
153
    auto* oedge_Y = oedges.at(0);
5✔
154
    if (!iedge_X || !iedge_W || !oedge_Y) {
5✔
155
        return false;
×
156
    }
×
157
    bool has_bias = iedge_B != nullptr;
5✔
158

159
    // Get access nodes
160
    auto* access_X = dynamic_cast<data_flow::AccessNode*>(&iedge_X->src());
5✔
161
    auto* access_W = dynamic_cast<data_flow::AccessNode*>(&iedge_W->src());
5✔
162
    auto* access_B = (has_bias ? dynamic_cast<data_flow::AccessNode*>(&iedge_B->src()) : nullptr);
5✔
163
    auto* access_Y = dynamic_cast<data_flow::AccessNode*>(&oedge_Y->dst());
5✔
164
    if (!access_X || !access_W || (has_bias && !access_B) || !access_Y) {
5✔
165
        return false;
×
166
    }
×
167

168
    // Get block & its parent
169
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
5✔
170
    if (!block) {
5✔
171
        return false;
×
172
    }
×
173
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
174
    auto* block_parent = dynamic_cast<structured_control_flow::Sequence*>(scope_analysis.parent_scope(block));
5✔
175
    if (!block_parent) {
5✔
176
        return false;
×
177
    }
×
178
    size_t block_index = block_parent->index(*block);
5✔
179
    if (block_index >= block_parent->size()) {
5✔
180
        return false;
×
181
    }
×
182

183
    // Determine BLAS precision
184
    blas::BLAS_Precision precision;
5✔
185
    types::Scalar base_type(this->primitive_type(dfg));
5✔
186
    switch (base_type.primitive_type()) {
5✔
187
        case types::PrimitiveType::Half:
×
188
            precision = blas::BLAS_Precision::h;
×
189
            break;
×
190
        case types::PrimitiveType::Float:
5✔
191
            precision = blas::BLAS_Precision::s;
5✔
192
            break;
5✔
193
        case types::PrimitiveType::Double:
×
194
            precision = blas::BLAS_Precision::d;
×
195
            break;
×
196
        default:
×
197
            return false;
×
198
    }
5✔
199

200
    // Create new sequence for expansion
201
    auto& new_sequence = builder.add_sequence_before(
5✔
202
        *block_parent, *block, block_parent->at(block_index).second.assignments(), block->debug_info()
5✔
203
    );
5✔
204

205
    // Dimensions, i.e., 1D, 2D, 3D, ...
206
    size_t dims = this->kernel_shape_.size();
5✔
207
    symbolic::MultiExpression out_shape;
5✔
208
    out_shape.reserve(dims);
5✔
209
    // out_shape[i] = (shape[i + 2] + pads[i] + pads[dims + i] - dilations[i] * (kernel_shape[i] - 1) - 1)
210
    //                 / strides[i] + 1
211
    for (size_t i = 0; i < dims; i++) {
15✔
212
        out_shape.push_back(symbolic::add(
10✔
213
            symbolic::div(
10✔
214
                symbolic::sub(
10✔
215
                    symbolic::
10✔
216
                        sub(symbolic::add(this->shape_[i + 2], symbolic::add(this->pads_[i], this->pads_[dims + i])),
10✔
217
                            symbolic::mul(this->dilations_[i], symbolic::sub(this->kernel_shape_[i], symbolic::one()))),
10✔
218
                    symbolic::one()
10✔
219
                ),
10✔
220
                this->strides_[i]
10✔
221
            ),
10✔
222
            symbolic::one()
10✔
223
        ));
10✔
224
    }
10✔
225
    types::Scalar indvar_type(types::PrimitiveType::Int64);
5✔
226

227
    // Compute spatial dimensions: K = C_in * prod(kernel_shape), spatial = prod(out_shape)
228
    // Patches layout: [K, spatial] = [C_in, kernel..., out_shape...]
229
    // GEMM: Y[F, spatial] = W[F, K] × patches[K, spatial]  (NoTrans × NoTrans)
230
    // Output Y is [N, F, out_shape...] — GEMM writes directly per batch via pointer reference
231

232
    if (symbolic::eq(this->group_, symbolic::one())) {
5✔
233
        /* ===== No groups ====================================================================== */
234

235
        // Compute K = C_in * prod(kernel_shape) and spatial = prod(out_shape)
236
        symbolic::Expression gemm_k = this->shape_[1]; // C_in
5✔
237
        for (size_t i = 0; i < dims; i++) {
15✔
238
            gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
10✔
239
        }
10✔
240
        symbolic::Expression spatial = symbolic::one();
5✔
241
        for (size_t i = 0; i < dims; i++) {
15✔
242
            spatial = symbolic::mul(spatial, out_shape[i]);
10✔
243
        }
10✔
244

245
        // Patches buffer: K × spatial (per-batch, reused across batches)
246
        symbolic::Expression patches_size = symbolic::mul(gemm_k, spatial);
5✔
247
        types::Pointer patches_type(base_type);
5✔
248
        auto patches_container = builder.find_new_name("_patches");
5✔
249
        builder.add_container(patches_container, patches_type);
5✔
250

251
        // Batch loop
252
        auto n_container = builder.find_new_name("_n");
5✔
253
        builder.add_container(n_container, indvar_type);
5✔
254
        auto n = symbolic::symbol(n_container);
5✔
255
        auto& loop_n = builder.add_map(
5✔
256
            new_sequence,
5✔
257
            n,
5✔
258
            symbolic::Lt(n, this->shape_[0]),
5✔
259
            symbolic::zero(),
5✔
260
            symbolic::add(n, symbolic::one()),
5✔
261
            ScheduleType_Sequential::create(),
5✔
262
            {},
5✔
263
            block->debug_info()
5✔
264
        );
5✔
265

266
        // Malloc patches
267
        auto& patches_malloc_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
268
        {
5✔
269
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
5✔
270
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
5✔
271
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
5✔
272
            );
5✔
273
            builder.add_computational_memlet(
5✔
274
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
5✔
275
            );
5✔
276
        }
5✔
277

278
        // Memset patches to zero (inside batch loop, per-batch)
279
        auto& patches_memset_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
280
        {
5✔
281
            auto& patches_access = builder.add_access(patches_memset_block, patches_container, this->debug_info());
5✔
282
            auto& libnode = builder.add_library_node<stdlib::MemsetNode>(
5✔
283
                patches_memset_block,
5✔
284
                this->debug_info(),
5✔
285
                symbolic::zero(),
5✔
286
                symbolic::mul(patches_size, symbolic::size_of_type(base_type))
5✔
287
            );
5✔
288
            builder.add_computational_memlet(
5✔
289
                patches_memset_block, libnode, "_ptr", patches_access, {}, patches_type, this->debug_info()
5✔
290
            );
5✔
291
        }
5✔
292

293
        // Im2col: nested loops over channels, kernel, output positions
294
        // Loop order: c, k[0..dims-1], o[0..dims-1]
295
        // Patches layout: [C_in, kernel..., out_shape...] — row-major gives [K, spatial]
296
        structured_control_flow::Sequence* current_seq = &loop_n.root();
5✔
297

298
        // Channel loop
299
        auto c_container = builder.find_new_name("_c");
5✔
300
        builder.add_container(c_container, indvar_type);
5✔
301
        auto c = symbolic::symbol(c_container);
5✔
302
        auto& loop_c = builder.add_map(
5✔
303
            *current_seq,
5✔
304
            c,
5✔
305
            symbolic::Lt(c, this->shape_[1]),
5✔
306
            symbolic::zero(),
5✔
307
            symbolic::add(c, symbolic::one()),
5✔
308
            ScheduleType_Sequential::create(),
5✔
309
            {},
5✔
310
            block->debug_info()
5✔
311
        );
5✔
312
        current_seq = &loop_c.root();
5✔
313

314
        // Kernel dimension loops
315
        symbolic::SymbolVec ks;
5✔
316
        ks.reserve(dims);
5✔
317
        symbolic::MultiExpression input_indices; // i_expr for each spatial dim
5✔
318
        input_indices.reserve(dims);
5✔
319
        for (size_t i = 0; i < dims; i++) {
15✔
320
            auto k_container = builder.find_new_name("_k");
10✔
321
            builder.add_container(k_container, indvar_type);
10✔
322
            auto k = symbolic::symbol(k_container);
10✔
323
            ks.push_back(k);
10✔
324
            auto& loop_k = builder.add_map(
10✔
325
                *current_seq,
10✔
326
                k,
10✔
327
                symbolic::Lt(k, this->kernel_shape_[i]),
10✔
328
                symbolic::zero(),
10✔
329
                symbolic::add(k, symbolic::one()),
10✔
330
                ScheduleType_Sequential::create(),
10✔
331
                {},
10✔
332
                block->debug_info()
10✔
333
            );
10✔
334
            current_seq = &loop_k.root();
10✔
335
        }
10✔
336

337
        // Output spatial dimension loops (with padding/dilation bounds check)
338
        symbolic::SymbolVec os;
5✔
339
        os.reserve(dims);
5✔
340
        for (size_t i = 0; i < dims; i++) {
15✔
341
            auto o_container = builder.find_new_name("_o");
10✔
342
            builder.add_container(o_container, indvar_type);
10✔
343
            auto o = symbolic::symbol(o_container);
10✔
344
            os.push_back(o);
10✔
345
            // i_expr = o * stride - pad + k * dilation
346
            auto i_expr = symbolic::
10✔
347
                add(symbolic::sub(symbolic::mul(o, this->strides_[i]), this->pads_[i]),
10✔
348
                    symbolic::mul(ks[i], this->dilations_[i]));
10✔
349
            input_indices.push_back(i_expr);
10✔
350
            auto& loop_o = builder.add_map(
10✔
351
                *current_seq,
10✔
352
                o,
10✔
353
                symbolic::And(symbolic::Lt(o, out_shape[i]), symbolic::Lt(i_expr, this->shape_[i + 2])),
10✔
354
                symbolic::zero(),
10✔
355
                symbolic::add(o, symbolic::one()),
10✔
356
                ScheduleType_Sequential::create(),
10✔
357
                {},
10✔
358
                block->debug_info()
10✔
359
            );
10✔
360
            current_seq = &loop_o.root();
10✔
361
        }
10✔
362

363
        // Patches subset: [c, k0, k1, ..., o0, o1, ...]
364
        // Patches shape: [C_in, kH, kW, ..., H_out, W_out, ...]
365
        data_flow::Subset patches_subset;
5✔
366
        patches_subset.push_back(c);
5✔
367
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
5✔
368
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
5✔
369
        symbolic::MultiExpression patches_shape;
5✔
370
        patches_shape.push_back(this->shape_[1]); // C_in
5✔
371
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
5✔
372
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
5✔
373
        types::Tensor patches_tensor_type(base_type, patches_shape);
5✔
374

375
        // X subset: [n, c, i0, i1, ...]
376
        data_flow::Subset subset_X;
5✔
377
        subset_X.push_back(n);
5✔
378
        subset_X.push_back(c);
5✔
379
        subset_X.insert(subset_X.end(), input_indices.begin(), input_indices.end());
5✔
380

381
        // Copy X[n, c, i...] → patches[c, k..., o...]
382
        auto& im2col_block = builder.add_block(*current_seq, {}, block->debug_info());
5✔
383
        {
5✔
384
            auto& X_access = builder.add_access(im2col_block, access_X->data(), access_X->debug_info());
5✔
385
            auto& patches_access = builder.add_access(im2col_block, patches_container, this->debug_info());
5✔
386
            auto& tasklet =
5✔
387
                builder.add_tasklet(im2col_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
388
            builder.add_computational_memlet(
5✔
389
                im2col_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
5✔
390
            );
5✔
391
            builder.add_computational_memlet(
5✔
392
                im2col_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
5✔
393
            );
5✔
394
        }
5✔
395

396
        // Reference to Y[n, :, :, ...] — offset = n * F * spatial
397
        auto ref_Y_container = builder.find_new_name("_ref_Y");
5✔
398
        types::Scalar ref_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
5✔
399
        types::Pointer ref_Y_type(ref_Y_base_type);
5✔
400
        builder.add_container(ref_Y_container, ref_Y_type);
5✔
401
        auto ref_Y_offset = symbolic::mul(n, symbolic::mul(this->output_channels_, spatial));
5✔
402
        auto& ref_Y_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
403
        {
5✔
404
            auto& Y_access = builder.add_access(ref_Y_block, access_Y->data(), access_Y->debug_info());
5✔
405
            auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, access_Y->debug_info());
5✔
406
            builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_offset}, ref_Y_type);
5✔
407
        }
5✔
408

409
        // GEMM: Y[n][F, spatial] = W[F, K] × patches[K, spatial]
410
        // NoTrans × NoTrans, lda=K, ldb=spatial, ldc=spatial
411
        auto& gemm_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
412
        {
5✔
413
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
5✔
414
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
5✔
415
            auto& W_access = builder.add_access(gemm_block, access_W->data(), access_W->debug_info());
5✔
416
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
5✔
417
            auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
5✔
418
            auto& ref_Y_access_out = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
5✔
419
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
5✔
420
                gemm_block,
5✔
421
                this->debug_info(),
5✔
422
                blas::ImplementationType_BLAS,
5✔
423
                precision,
5✔
424
                blas::BLAS_Layout::RowMajor,
5✔
425
                blas::BLAS_Transpose::No, // transA
5✔
426
                blas::BLAS_Transpose::No, // transB
5✔
427
                this->output_channels_, // m = F
5✔
428
                spatial, // n = spatial
5✔
429
                gemm_k, // k = K
5✔
430
                gemm_k, // lda = K
5✔
431
                spatial, // ldb = spatial
5✔
432
                spatial // ldc = spatial
5✔
433
            );
5✔
434
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
5✔
435
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
5✔
436
            builder.add_computational_memlet(
5✔
437
                gemm_block,
5✔
438
                W_access,
5✔
439
                libnode,
5✔
440
                "__A",
5✔
441
                {},
5✔
442
                types::Pointer(types::Scalar(iedge_W->base_type().primitive_type())),
5✔
443
                iedge_W->debug_info()
5✔
444
            );
5✔
445
            builder.add_computational_memlet(
5✔
446
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
5✔
447
            );
5✔
448
            builder.add_computational_memlet(
5✔
449
                gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, oedge_Y->debug_info()
5✔
450
            );
5✔
451
            builder.add_computational_memlet(
5✔
452
                gemm_block, libnode, "__C", ref_Y_access_out, {}, ref_Y_type, oedge_Y->debug_info()
5✔
453
            );
5✔
454
        }
5✔
455

456
        // Add bias if available: Y[n, f, o...] += B[f]
457
        if (has_bias) {
5✔
NEW
458
            auto l_container = builder.find_new_name("_l");
×
NEW
459
            builder.add_container(l_container, indvar_type);
×
NEW
460
            auto l = symbolic::symbol(l_container);
×
NEW
461
            auto& loop_l = builder.add_map(
×
NEW
462
                loop_n.root(),
×
NEW
463
                l,
×
NEW
464
                symbolic::Lt(l, this->output_channels_),
×
UNCOV
465
                symbolic::zero(),
×
NEW
466
                symbolic::add(l, symbolic::one()),
×
UNCOV
467
                ScheduleType_Sequential::create(),
×
UNCOV
468
                {},
×
UNCOV
469
                block->debug_info()
×
UNCOV
470
            );
×
NEW
471
            structured_control_flow::Sequence* bias_seq = &loop_l.root();
×
472

NEW
473
            symbolic::SymbolVec bias_os;
×
NEW
474
            bias_os.reserve(dims);
×
NEW
475
            for (size_t i = 0; i < dims; i++) {
×
NEW
476
                auto o_container = builder.find_new_name("_o");
×
NEW
477
                builder.add_container(o_container, indvar_type);
×
NEW
478
                auto o = symbolic::symbol(o_container);
×
NEW
479
                bias_os.push_back(o);
×
NEW
480
                auto& loop_o = builder.add_map(
×
NEW
481
                    *bias_seq,
×
NEW
482
                    o,
×
NEW
483
                    symbolic::Lt(o, out_shape[i]),
×
NEW
484
                    symbolic::zero(),
×
NEW
485
                    symbolic::add(o, symbolic::one()),
×
NEW
486
                    ScheduleType_Sequential::create(),
×
NEW
487
                    {},
×
NEW
488
                    block->debug_info()
×
NEW
489
                );
×
NEW
490
                bias_seq = &loop_o.root();
×
NEW
491
            }
×
492

NEW
493
            data_flow::Subset Y_subset;
×
NEW
494
            Y_subset.push_back(n);
×
NEW
495
            Y_subset.push_back(l);
×
NEW
496
            Y_subset.insert(Y_subset.end(), bias_os.begin(), bias_os.end());
×
497

NEW
498
            auto& bias_block = builder.add_block(*bias_seq, {}, block->debug_info());
×
NEW
499
            {
×
NEW
500
                auto& B_access = builder.add_access(bias_block, access_B->data(), access_B->debug_info());
×
NEW
501
                auto& Y_access_in = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
NEW
502
                auto& Y_access_out = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
NEW
503
                auto& tasklet = builder.add_tasklet(
×
NEW
504
                    bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
NEW
505
                );
×
NEW
506
                builder.add_computational_memlet(
×
NEW
507
                    bias_block, Y_access_in, tasklet, "_in1", Y_subset, oedge_Y->base_type(), this->debug_info()
×
NEW
508
                );
×
NEW
509
                builder.add_computational_memlet(
×
NEW
510
                    bias_block, B_access, tasklet, "_in2", {l}, iedge_B->base_type(), iedge_B->debug_info()
×
NEW
511
                );
×
NEW
512
                builder.add_computational_memlet(
×
NEW
513
                    bias_block, tasklet, "_out", Y_access_out, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
NEW
514
                );
×
NEW
515
            }
×
UNCOV
516
        }
×
517

518
        // Free patches
519
        auto& patches_free_block = builder.add_block(loop_n.root(), {}, block->debug_info());
5✔
520
        {
5✔
521
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
522
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
523
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
5✔
524
            builder.add_computational_memlet(
5✔
525
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
5✔
526
            );
5✔
527
            builder.add_computational_memlet(
5✔
528
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
5✔
529
            );
5✔
530
        }
5✔
531

532
        /* ===== No groups ====================================================================== */
533

534
    } else {
5✔
535
        /* ===== Groups ========================================================================= */
536

537
        auto in_channels = symbolic::div(this->shape_[1], this->group_);
×
538
        auto out_channels = symbolic::div(this->output_channels_, this->group_);
×
539

540
        // Compute K_group = in_channels * prod(kernel_shape), spatial = prod(out_shape)
NEW
541
        symbolic::Expression gemm_k = in_channels;
×
NEW
542
        for (size_t i = 0; i < dims; i++) {
×
NEW
543
            gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
×
NEW
544
        }
×
NEW
545
        symbolic::Expression spatial = symbolic::one();
×
NEW
546
        for (size_t i = 0; i < dims; i++) {
×
NEW
547
            spatial = symbolic::mul(spatial, out_shape[i]);
×
NEW
548
        }
×
549

550
        // Patches buffer: K_group × spatial (per group iteration, reused)
NEW
551
        symbolic::Expression patches_size = symbolic::mul(gemm_k, spatial);
×
NEW
552
        types::Pointer patches_type(base_type);
×
NEW
553
        auto patches_container = builder.find_new_name("_patches");
×
NEW
554
        builder.add_container(patches_container, patches_type);
×
555

556
        // Batch loop
557
        auto n_container = builder.find_new_name("_n");
×
558
        builder.add_container(n_container, indvar_type);
×
559
        auto n = symbolic::symbol(n_container);
×
560
        auto& loop_n = builder.add_map(
×
561
            new_sequence,
×
562
            n,
×
563
            symbolic::Lt(n, this->shape_[0]),
×
564
            symbolic::zero(),
×
565
            symbolic::add(n, symbolic::one()),
×
566
            ScheduleType_Sequential::create(),
×
567
            {},
×
568
            block->debug_info()
×
569
        );
×
570

571
        // Group loop
572
        auto g_container = builder.find_new_name("_g");
×
573
        builder.add_container(g_container, indvar_type);
×
574
        auto g = symbolic::symbol(g_container);
×
575
        auto& loop_g = builder.add_map(
×
576
            loop_n.root(),
×
577
            g,
×
578
            symbolic::Lt(g, this->group_),
×
579
            symbolic::zero(),
×
580
            symbolic::add(g, symbolic::one()),
×
581
            ScheduleType_Sequential::create(),
×
582
            {},
×
583
            block->debug_info()
×
584
        );
×
585

586

587
        // Malloc patches (outside loops)
588
        auto& patches_malloc_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
589
        {
×
590
            auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
×
591
            auto& libnode = builder.add_library_node<stdlib::MallocNode>(
×
592
                patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
×
593
            );
×
594
            builder.add_computational_memlet(
×
595
                patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
×
596
            );
×
597
        }
×
598

599
        // Memset patches to zero (per batch×group iteration)
600
        auto& patches_memset_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
601
        {
×
602
            auto& patches_access = builder.add_access(patches_memset_block, patches_container, this->debug_info());
×
603
            auto& libnode = builder.add_library_node<stdlib::MemsetNode>(
×
604
                patches_memset_block,
×
605
                this->debug_info(),
×
606
                symbolic::zero(),
×
607
                symbolic::mul(patches_size, symbolic::size_of_type(base_type))
×
608
            );
×
609
            builder.add_computational_memlet(
×
610
                patches_memset_block, libnode, "_ptr", patches_access, {}, patches_type, this->debug_info()
×
611
            );
×
612
        }
×
613

614
        // Im2col loops: c, k[0..dims-1], o[0..dims-1]
615
        structured_control_flow::Sequence* current_seq = &loop_g.root();
×
616

617
        // Channel loop (over in_channels per group)
618
        auto c_container = builder.find_new_name("_c");
×
619
        builder.add_container(c_container, indvar_type);
×
620
        auto c = symbolic::symbol(c_container);
×
621
        auto& loop_c = builder.add_map(
×
622
            *current_seq,
×
623
            c,
×
624
            symbolic::Lt(c, in_channels),
×
625
            symbolic::zero(),
×
626
            symbolic::add(c, symbolic::one()),
×
627
            ScheduleType_Sequential::create(),
×
628
            {},
×
629
            block->debug_info()
×
630
        );
×
631
        current_seq = &loop_c.root();
×
632

633
        // Kernel dimension loops
634
        symbolic::SymbolVec ks;
×
635
        ks.reserve(dims);
×
636
        for (size_t i = 0; i < dims; i++) {
×
637
            auto k_container = builder.find_new_name("_k");
×
638
            builder.add_container(k_container, indvar_type);
×
639
            auto k = symbolic::symbol(k_container);
×
640
            ks.push_back(k);
×
641
            auto& loop_k = builder.add_map(
×
642
                *current_seq,
×
643
                k,
×
NEW
644
                symbolic::Lt(k, this->kernel_shape_[i]),
×
645
                symbolic::zero(),
×
646
                symbolic::add(k, symbolic::one()),
×
647
                ScheduleType_Sequential::create(),
×
648
                {},
×
649
                block->debug_info()
×
650
            );
×
651
            current_seq = &loop_k.root();
×
652
        }
×
653

654
        // Output spatial loops (with bounds check for padding/dilation)
NEW
655
        symbolic::SymbolVec os;
×
NEW
656
        os.reserve(dims);
×
NEW
657
        symbolic::MultiExpression input_indices;
×
NEW
658
        input_indices.reserve(dims);
×
NEW
659
        for (size_t i = 0; i < dims; i++) {
×
NEW
660
            auto o_container = builder.find_new_name("_o");
×
NEW
661
            builder.add_container(o_container, indvar_type);
×
NEW
662
            auto o = symbolic::symbol(o_container);
×
NEW
663
            os.push_back(o);
×
NEW
664
            auto i_expr = symbolic::
×
NEW
665
                add(symbolic::sub(symbolic::mul(o, this->strides_[i]), this->pads_[i]),
×
NEW
666
                    symbolic::mul(ks[i], this->dilations_[i]));
×
NEW
667
            input_indices.push_back(i_expr);
×
NEW
668
            auto& loop_o = builder.add_map(
×
NEW
669
                *current_seq,
×
NEW
670
                o,
×
NEW
671
                symbolic::And(symbolic::Lt(o, out_shape[i]), symbolic::Lt(i_expr, this->shape_[i + 2])),
×
NEW
672
                symbolic::zero(),
×
NEW
673
                symbolic::add(o, symbolic::one()),
×
NEW
674
                ScheduleType_Sequential::create(),
×
NEW
675
                {},
×
NEW
676
                block->debug_info()
×
NEW
677
            );
×
NEW
678
            current_seq = &loop_o.root();
×
NEW
679
        }
×
680

681
        // Patches subset: [c, k0, k1, ..., o0, o1, ...]
682
        // Patches shape: [in_channels, kernel..., out_shape...]
683
        data_flow::Subset patches_subset;
×
684
        patches_subset.push_back(c);
×
685
        patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
×
686
        patches_subset.insert(patches_subset.end(), os.begin(), os.end());
×
687
        symbolic::MultiExpression patches_shape;
×
688
        patches_shape.push_back(in_channels);
×
689
        patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
×
690
        patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
×
691
        types::Tensor patches_tensor_type(base_type, patches_shape);
×
692

693
        // X subset: [n, g * in_channels + c, i0, i1, ...]
694
        data_flow::Subset subset_X;
×
695
        subset_X.push_back(n);
×
696
        subset_X.push_back(symbolic::add(symbolic::mul(in_channels, g), c));
×
NEW
697
        subset_X.insert(subset_X.end(), input_indices.begin(), input_indices.end());
×
698

699
        // Copy X → patches
NEW
700
        auto& im2col_block = builder.add_block(*current_seq, {}, block->debug_info());
×
701
        {
×
NEW
702
            auto& X_access = builder.add_access(im2col_block, access_X->data(), access_X->debug_info());
×
NEW
703
            auto& patches_access = builder.add_access(im2col_block, patches_container, this->debug_info());
×
704
            auto& tasklet =
×
NEW
705
                builder.add_tasklet(im2col_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
×
706
            builder.add_computational_memlet(
×
NEW
707
                im2col_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
×
708
            );
×
709
            builder.add_computational_memlet(
×
NEW
710
                im2col_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
×
711
            );
×
712
        }
×
713

714
        // Reference to W[g, :, :, ...] — offset = g * out_channels * K_group
715
        auto ref_W_container = builder.find_new_name("_ref_W");
×
716
        types::Scalar ref_W_base_type(builder.subject().type(access_W->data()).primitive_type());
×
717
        types::Pointer ref_W_type(ref_W_base_type);
×
718
        builder.add_container(ref_W_container, ref_W_type);
×
NEW
719
        auto ref_W_offset = symbolic::mul(g, symbolic::mul(out_channels, gemm_k));
×
720
        auto& ref_W_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
721
        {
×
722
            auto& W_access = builder.add_access(ref_W_block, access_W->data(), access_W->debug_info());
×
723
            auto& ref_W_access = builder.add_access(ref_W_block, ref_W_container, access_W->debug_info());
×
NEW
724
            builder.add_reference_memlet(ref_W_block, W_access, ref_W_access, {ref_W_offset}, ref_W_type);
×
725
        }
×
726

727
        // Reference to Y[n, g*out_channels, 0, ...] — offset = (n * F + g * out_channels) * spatial
728
        auto ref_Y_container = builder.find_new_name("_ref_Y");
×
729
        types::Scalar ref_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
×
730
        types::Pointer ref_Y_type(ref_Y_base_type);
×
731
        builder.add_container(ref_Y_container, ref_Y_type);
×
NEW
732
        auto ref_Y_offset =
×
NEW
733
            symbolic::mul(symbolic::add(symbolic::mul(this->output_channels_, n), symbolic::mul(out_channels, g)), spatial);
×
734
        auto& ref_Y_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
735
        {
×
736
            auto& Y_access = builder.add_access(ref_Y_block, access_Y->data(), access_Y->debug_info());
×
737
            auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, access_Y->debug_info());
×
NEW
738
            builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_offset}, ref_Y_type);
×
739
        }
×
740

741
        // GEMM: Y_ref[out_channels, spatial] = W_ref[out_channels, K_group] × patches[K_group, spatial]
742
        auto& gemm_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
743
        {
×
744
            auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
×
745
            auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
×
746
            auto& ref_W_access = builder.add_access(gemm_block, ref_W_container, access_W->debug_info());
×
747
            auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
×
748
            auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
749
            auto& ref_Y_access_out = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
×
750
            auto& libnode = builder.add_library_node<blas::GEMMNode>(
×
751
                gemm_block,
×
752
                this->debug_info(),
×
753
                blas::ImplementationType_BLAS,
×
NEW
754
                precision,
×
NEW
755
                blas::BLAS_Layout::RowMajor,
×
756
                blas::BLAS_Transpose::No, // transA
×
757
                blas::BLAS_Transpose::No, // transB
×
NEW
758
                out_channels, // m
×
NEW
759
                spatial, // n
×
760
                gemm_k, // k
×
761
                gemm_k, // lda
×
NEW
762
                spatial, // ldb
×
NEW
763
                spatial // ldc
×
764
            );
×
765
            builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
×
766
            builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
×
767
            builder
×
768
                .add_computational_memlet(gemm_block, ref_W_access, libnode, "__A", {}, ref_W_type, iedge_W->debug_info());
×
769
            builder.add_computational_memlet(
×
770
                gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info()
×
771
            );
×
772
            builder.add_computational_memlet(
×
773
                gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, oedge_Y->debug_info()
×
774
            );
×
775
            builder.add_computational_memlet(
×
776
                gemm_block, libnode, "__C", ref_Y_access_out, {}, ref_Y_type, oedge_Y->debug_info()
×
777
            );
×
778
        }
×
779

780
        // Add bias if available: Y[n, g*out_channels + l, o...] += B[g*out_channels + l]
781
        if (has_bias) {
×
UNCOV
782
            auto l_container = builder.find_new_name("_l");
×
783
            builder.add_container(l_container, indvar_type);
×
784
            auto l = symbolic::symbol(l_container);
×
785
            auto& loop_l = builder.add_map(
×
786
                loop_g.root(),
×
787
                l,
×
788
                symbolic::Lt(l, out_channels),
×
789
                symbolic::zero(),
×
790
                symbolic::add(l, symbolic::one()),
×
791
                ScheduleType_Sequential::create(),
×
792
                {},
×
793
                block->debug_info()
×
794
            );
×
NEW
795
            structured_control_flow::Sequence* bias_seq = &loop_l.root();
×
796

NEW
797
            symbolic::SymbolVec bias_os;
×
NEW
798
            bias_os.reserve(dims);
×
799
            for (size_t i = 0; i < dims; i++) {
×
800
                auto o_container = builder.find_new_name("_o");
×
801
                builder.add_container(o_container, indvar_type);
×
802
                auto o = symbolic::symbol(o_container);
×
NEW
803
                bias_os.push_back(o);
×
804
                auto& loop_o = builder.add_map(
×
NEW
805
                    *bias_seq,
×
806
                    o,
×
807
                    symbolic::Lt(o, out_shape[i]),
×
808
                    symbolic::zero(),
×
809
                    symbolic::add(o, symbolic::one()),
×
810
                    ScheduleType_Sequential::create(),
×
811
                    {},
×
812
                    block->debug_info()
×
813
                );
×
NEW
814
                bias_seq = &loop_o.root();
×
815
            }
×
816

UNCOV
817
            data_flow::Subset Y_subset;
×
818
            Y_subset.push_back(n);
×
819
            Y_subset.push_back(symbolic::add(symbolic::mul(out_channels, g), l));
×
NEW
820
            Y_subset.insert(Y_subset.end(), bias_os.begin(), bias_os.end());
×
821
            auto B_subset = symbolic::add(symbolic::mul(out_channels, g), l);
×
822

NEW
823
            auto& bias_block = builder.add_block(*bias_seq, {}, block->debug_info());
×
824
            {
×
825
                auto& B_access = builder.add_access(bias_block, access_B->data(), access_B->debug_info());
×
826
                auto& Y_access_in = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
827
                auto& Y_access_out = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
×
828
                auto& tasklet = builder.add_tasklet(
×
829
                    bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info()
×
830
                );
×
831
                builder.add_computational_memlet(
×
832
                    bias_block, Y_access_in, tasklet, "_in1", Y_subset, oedge_Y->base_type(), this->debug_info()
×
833
                );
×
834
                builder.add_computational_memlet(
×
835
                    bias_block, B_access, tasklet, "_in2", {B_subset}, iedge_B->base_type(), iedge_B->debug_info()
×
836
                );
×
837
                builder.add_computational_memlet(
×
838
                    bias_block, tasklet, "_out", Y_access_out, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
×
839
                );
×
840
            }
×
841
        }
×
842

843
        // Free patches
844
        auto& patches_free_block = builder.add_block(loop_g.root(), {}, block->debug_info());
×
845
        {
×
846
            auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
847
            auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
×
848
            auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
×
849
            builder.add_computational_memlet(
×
850
                patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
×
851
            );
×
852
            builder.add_computational_memlet(
×
853
                patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
×
854
            );
×
855
        }
×
856

857
        /* ===== Groups ========================================================================= */
858
    }
×
859

860
    // Clean up the original block
861
    builder.remove_memlet(*block, *iedge_X);
5✔
862
    builder.remove_memlet(*block, *iedge_W);
5✔
863
    if (has_bias) {
5✔
864
        builder.remove_memlet(*block, *iedge_B);
×
865
    }
×
866
    builder.remove_memlet(*block, *oedge_Y);
5✔
867
    builder.remove_node(*block, *access_X);
5✔
868
    builder.remove_node(*block, *access_W);
5✔
869
    if (has_bias) {
5✔
870
        builder.remove_node(*block, *access_B);
×
871
    }
×
872
    builder.remove_node(*block, *access_Y);
5✔
873
    builder.remove_node(*block, *this);
5✔
874
    builder.remove_child(*block_parent, block_index + 1);
5✔
875

876
    {
5✔
877
        // CPU-specific transformations on the canonical expansion:
878
        //
879
        // 1. LoopTiling on h_out inside im2col
880
        //    - Split the output spatial H-loop into tiles of H_TILE rows
881
        //    - Reduces patches buffer from K×spatial to K×(H_TILE×W_out)
882
        //    - Tile size chosen to fit patches + weights in L2 cache
883
        // 2. TileFusion of im2col + GEMM
884
        //    - Fuse the tiled im2col block with the per-tile GEMM
885
        //    - Each tile: extract patches → GEMM → next tile
886
        //    - Both weights A and patches B reside in L2 when GEMM runs
887
        //
888
        // 3. MapCollapse on batch × h_tile
889
        //    - Flatten (batch, h_tile) into a single parallel dimension
890
        //    - Maximizes parallel work items (e.g., 32 batches × 4 tiles = 128)
891
        //
892
        // 4. Malloc privatization
893
        //    - Move patches malloc/free inside the parallel region
894
        //    - Each thread allocates its tile buffer once, reuses across tiles
895
    }
5✔
896
    {
5✔
897
        // GPU-specific transformations on the canonical expansion:
898
        //
899
        // 1. LoopTiling on output spatial dimensions
900
        //    - Tile H_out and W_out for thread block mapping
901
        //    - Tile sizes chosen for shared memory capacity
902
        //
903
        // 2. Map outer dimensions to GPU grid
904
        //    - batch → grid.z, output_channel_tile → grid.y, spatial_tile → grid.x
905
        //    - Apply CUDAScheduler to outer maps
906
        //
907
        // 3. Shared memory for patches tile
908
        //    - Replace malloc with __shared__ memory allocation
909
        //    - Cooperative loading: threads in a block load patches tile together
910
        //
911
        // 4. Register tiling for GEMM accumulation
912
        //    - Each thread computes a small tile of the output
913
        //    - Use fp_fma tasklets for multiply-accumulate
914
        //
915
        // 5. Memory coalescing
916
        //    - Ensure im2col reads and output writes are coalesced
917
        //    - Thread indexing aligned with contiguous memory dimension
918
    }
5✔
919

920
    return true;
5✔
921
}
5✔
922

923
symbolic::SymbolSet ConvNode::symbols() const {
    // Collects every symbol referenced by this node's symbolic attributes
    // (input shape, kernel shape, strides, pads, dilations, output channel
    // count, and group count).
    symbolic::SymbolSet syms;

    // Fold all atoms of a single expression into the result set.
    auto collect = [&syms](const symbolic::Expression& expr) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    };

    for (auto& expr : shape_) {
        collect(expr);
    }
    for (auto& expr : kernel_shape_) {
        collect(expr);
    }
    for (auto& expr : strides_) {
        collect(expr);
    }
    for (auto& expr : pads_) {
        collect(expr);
    }
    for (auto& expr : dilations_) {
        collect(expr);
    }
    collect(output_channels_);
    collect(group_);

    return syms;
}
×
960

961
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Substitutes old_expression with new_expression in every symbolic
    // attribute of the node, in place.
    auto substitute = [&old_expression, &new_expression](symbolic::Expression& expr) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    };

    for (auto& expr : shape_) {
        substitute(expr);
    }
    for (auto& expr : kernel_shape_) {
        substitute(expr);
    }
    for (auto& expr : strides_) {
        substitute(expr);
    }
    for (auto& expr : pads_) {
        substitute(expr);
    }
    for (auto& expr : dilations_) {
        substitute(expr);
    }
    substitute(output_channels_);
    substitute(group_);
}
×
980

981
// Creates a deep copy of this node for insertion into `parent` at `vertex`,
// carrying over all convolution attributes (shapes, strides, pads, dilations,
// output channels, group) and the original debug info.
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
    // Raw `new` rather than make_unique — presumably the constructor is not
    // publicly accessible to std::make_unique; TODO confirm before changing.
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
        element_id,
        this->debug_info(),
        vertex,
        parent,
        shape_,
        kernel_shape_,
        strides_,
        pads_,
        dilations_,
        output_channels_,
        group_
    ));
}
1✔
997

998
std::string ConvNode::toStr() const {
×
999
    std::stringstream result;
×
1000
    result << "Conv(shape=[";
×
1001
    for (size_t i = 0; i < shape_.size(); ++i) {
×
1002
        if (i > 0) {
×
1003
            result << ", ";
×
1004
        }
×
1005
        result << shape_[i]->__str__();
×
1006
    }
×
1007
    result << "], kernel_shape=[";
×
1008
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
1009
        if (i > 0) {
×
1010
            result << ", ";
×
1011
        }
×
1012
        result << kernel_shape_[i]->__str__();
×
1013
    }
×
1014
    result << "], strides=[";
×
1015
    for (size_t i = 0; i < strides_.size(); ++i) {
×
1016
        if (i > 0) {
×
1017
            result << ", ";
×
1018
        }
×
1019
        result << strides_[i]->__str__();
×
1020
    }
×
1021
    result << "], pads=[";
×
1022
    for (size_t i = 0; i < pads_.size(); ++i) {
×
1023
        if (i > 0) {
×
1024
            result << ", ";
×
1025
        }
×
1026
        result << pads_[i]->__str__();
×
1027
    }
×
1028
    result << "], dilations=[";
×
1029
    for (size_t i = 0; i < dilations_.size(); ++i) {
×
1030
        if (i > 0) {
×
1031
            result << ", ";
×
1032
        }
×
1033
        result << dilations_[i]->__str__();
×
1034
    }
×
1035
    result << "], output_channels=" + output_channels_->__str__();
×
1036
    result << ", group=" + group_->__str__() + ")";
×
1037
    return result.str();
×
1038
}
×
1039

1040
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    // Serializes a ConvNode into JSON: the library-node code, each symbolic
    // vector attribute as an array of expression strings, and the two scalar
    // attributes as expression strings.
    const ConvNode& conv_node = static_cast<const ConvNode&>(library_node);
    nlohmann::json j;

    j["code"] = conv_node.code().value();

    serializer::JSONSerializer serializer;

    // Converts a vector of symbolic expressions into a JSON array of strings.
    auto to_json_array = [&serializer](const std::vector<symbolic::Expression>& exprs) {
        nlohmann::json arr = nlohmann::json::array();
        for (auto& expr : exprs) {
            arr.push_back(serializer.expression(expr));
        }
        return arr;
    };

    j["shape"] = to_json_array(conv_node.shape());
    j["kernel_shape"] = to_json_array(conv_node.kernel_shape());
    j["strides"] = to_json_array(conv_node.strides());
    j["pads"] = to_json_array(conv_node.pads());
    j["dilations"] = to_json_array(conv_node.dilations());

    j["output_channels"] = serializer.expression(conv_node.output_channels());
    j["group"] = serializer.expression(conv_node.group());

    return j;
}
×
1078

1079
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // Reconstructs a ConvNode from its JSON form. "kernel_shape" is required;
    // the remaining vector attributes default to empty and the scalar
    // attributes default to 1 when absent.
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parses an optional JSON array of expression strings; a missing key
    // yields an empty vector.
    auto parse_optional_vector = [&j](const char* key) {
        std::vector<symbolic::Expression> result;
        if (j.contains(key)) {
            for (const auto& entry : j[key]) {
                result.push_back(symbolic::parse(entry.get<std::string>()));
            }
        }
        return result;
    };

    // Parses an optional scalar expression; a missing key defaults to 1.
    auto parse_optional_scalar = [&j](const char* key) -> symbolic::Expression {
        if (j.contains(key)) {
            return symbolic::parse(j[key].get<std::string>());
        }
        return symbolic::one();
    };

    auto shape = parse_optional_vector("shape");

    // Required key — parsed unconditionally, matching the assert above.
    std::vector<symbolic::Expression> kernel_shape;
    for (const auto& dim : j["kernel_shape"]) {
        kernel_shape.push_back(symbolic::parse(dim.get<std::string>()));
    }

    auto strides = parse_optional_vector("strides");
    auto pads = parse_optional_vector("pads");
    auto dilations = parse_optional_vector("dilations");
    auto output_channels = parse_optional_scalar("output_channels");
    auto group = parse_optional_scalar("group");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
×
1136

1137
} // namespace tensor
1138
} // namespace math
1139
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc