• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 25388891763

05 May 2026 04:28PM UTC coverage: 65.223%. First build
25388891763

Pull #663

github

web-flow
Merge 27a19dea9 into 8df8b842d
Pull Request #663: Fast Conv Node for CPU

221 of 235 new or added lines in 1 file covered. (94.04%)

31622 of 48483 relevant lines covered (65.22%)

2351.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.38
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"

#include <map>
#include <memory>
#include <sstream>

#include "symengine/integer.h"
#include "symengine/symengine_rcp.h"

#include "sdfg/analysis/analysis.h"
#include "sdfg/analysis/scope_analysis.h"
#include "sdfg/builder/structured_sdfg_builder.h"
#include "sdfg/data_flow/access_node.h"
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
#include "sdfg/data_flow/library_nodes/stdlib/memset.h"
#include "sdfg/data_flow/memlet.h"
#include "sdfg/data_flow/tasklet.h"
#include "sdfg/exceptions.h"
#include "sdfg/structured_control_flow/block.h"
#include "sdfg/structured_control_flow/map.h"
#include "sdfg/structured_control_flow/sequence.h"
#include "sdfg/symbolic/symbolic.h"
#include "sdfg/types/pointer.h"
#include "sdfg/types/scalar.h"
#include "sdfg/types/tensor.h"
#include "sdfg/types/type.h"

30
namespace sdfg {
31
namespace math {
32
namespace tensor {
33

34
// Constructs a Conv library node with output connector "Y" and input
// connectors "X" (input tensor), "W" (weights) and "B" (bias); X and W are
// required, B is optional (enforced in validate()).
//
// Attribute naming mirrors the ONNX Conv operator (kernel_shape, strides,
// pads, dilations, group) — presumably this node lowers that op; confirm
// against the frontend that creates it.
//
// Parameters:
//   shape           - full input shape; validate() requires
//                     kernel_shape.size() + 2 entries, i.e. (N, C, spatial...)
//   kernel_shape    - spatial extents of the kernel; its size defines the
//                     number of spatial dimensions
//   strides         - one stride per spatial axis
//   pads            - 2 * spatial dims entries (start and end per axis)
//   dilations       - one dilation per spatial axis
//   output_channels - total output channel count (across all groups)
//   group           - number of convolution groups
ConvNode::ConvNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<symbolic::Expression>& kernel_shape,
    const std::vector<symbolic::Expression>& strides,
    const std::vector<symbolic::Expression>& pads,
    const std::vector<symbolic::Expression>& dilations,
    symbolic::Expression output_channels,
    symbolic::Expression group
)
    : TensorNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_Conv,
          {"Y"},
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
          data_flow::ImplementationType_NONE
      ),
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
      output_channels_(output_channels), group_(group) {}
59

60
void ConvNode::validate(const Function& function) const {
70✔
61
    TensorNode::validate(function);
70✔
62

63
    auto& graph = this->get_parent();
70✔
64

65
    // Custom validation for ConvNode that handles optional bias input
66
    // We expect X, W as required inputs and optionally B (bias)
67

68
    // Collect all input edges by connector name
69
    std::map<std::string, const data_flow::Memlet*> input_edges;
70✔
70
    for (auto& iedge : graph.in_edges(*this)) {
140✔
71
        input_edges[iedge.dst_conn()] = &iedge;
140✔
72
    }
140✔
73

74
    // Check that required inputs X and W are present
75
    if (input_edges.find("X") == input_edges.end()) {
70✔
76
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
77
    }
×
78
    if (input_edges.find("W") == input_edges.end()) {
70✔
79
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
80
    }
×
81

82
    // Validate that parameters are not empty
83
    if (shape_.empty()) {
70✔
84
        throw InvalidSDFGException("ConvNode shape cannot be empty");
×
85
    }
×
86
    if (kernel_shape_.empty()) {
70✔
87
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
88
    }
×
89
    if (strides_.empty()) {
70✔
90
        throw InvalidSDFGException("ConvNode strides cannot be empty");
×
91
    }
×
92
    if (pads_.empty()) {
70✔
93
        throw InvalidSDFGException("ConvNode pads cannot be empty");
×
94
    }
×
95
    if (dilations_.empty()) {
70✔
96
        throw InvalidSDFGException("ConvNode dilations cannot be empty");
×
97
    }
×
98

99
    // Validate consistent dimensions
100
    size_t spatial_dims = kernel_shape_.size();
70✔
101

102
    if (shape_.size() != spatial_dims + 2) {
70✔
103
        throw InvalidSDFGException("ConvNode shape must match kernel spatial dimensions + 2");
×
104
    }
×
105

106
    if (strides_.size() != spatial_dims) {
70✔
107
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
108
    }
1✔
109

110
    if (pads_.size() != 2 * spatial_dims) {
69✔
111
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
112
    }
1✔
113

114
    if (dilations_.size() != spatial_dims) {
68✔
115
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
116
    }
×
117

118
    // Validate groups
119
    if (SymEngine::is_a<SymEngine::Integer>(*this->group_)) {
68✔
120
        auto group_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->group_)->as_int();
68✔
121
        if (SymEngine::is_a<SymEngine::Integer>(*this->shape_[1])) {
68✔
122
            auto input_channels_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->shape_[1])->as_int();
68✔
123
            if (input_channels_int % group_int != 0) {
68✔
124
                throw InvalidSDFGException("ConvNode input channels must be divisible by groups");
×
125
            }
×
126
        }
68✔
127
        if (SymEngine::is_a<SymEngine::Integer>(*this->output_channels_)) {
68✔
128
            auto output_channels_int =
68✔
129
                SymEngine::rcp_static_cast<const SymEngine::Integer>(this->output_channels_)->as_int();
68✔
130
            if (output_channels_int % group_int != 0) {
68✔
131
                throw InvalidSDFGException("ConvNode output channels must be divisible by groups");
×
132
            }
×
133
        }
68✔
134
    }
68✔
135
}
68✔
136

137
// Expands the Conv library node into explicit structured control flow.
//
// Strategy (im2col-style, judging from the code below): for every batch
// element `n` and group `g`, gather the relevant input window of X into a
// freshly malloc'ed "patches" buffer (zero-filling out-of-bounds taps), then
// compute the group's output as a single GEMM of W-slice x patches into a
// Y-slice, and finally add the optional bias in a separate loop nest.
//
// Returns true on success; false (leaving the graph untouched) when the node
// is not standalone in its block, inputs are malformed, or the element type
// has no BLAS precision mapping.
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // Validate nodes are standalone in the data flow graph
    // (4 nodes / 3 edges without bias; 5 nodes / 4 edges with bias).
    auto& dfg = this->get_parent();
    if ((dfg.nodes().size() != 4 || dfg.edges().size() != 3) && (dfg.nodes().size() != 5 || dfg.edges().size() != 4)) {
        return false;
    }

    // Get edges. The connector-indexed vectors always have an entry per
    // declared connector; a missing optional input yields a null pointer.
    auto iedges = dfg.in_edges_by_connector(*this);
    auto oedges = dfg.out_edges_by_connector(*this);
    if (iedges.size() != 3 || oedges.size() != 1) {
        return false;
    }
    auto* iedge_X = iedges.at(0);
    auto* iedge_W = iedges.at(1);
    auto* iedge_B = iedges.at(2);
    auto* oedge_Y = oedges.at(0);
    if (!iedge_X || !iedge_W || !oedge_Y) {
        return false;
    }
    // B is the only optional input (see the constructor's connector list).
    bool has_bias = iedge_B != nullptr;

    // Get access nodes
    auto* access_X = dynamic_cast<data_flow::AccessNode*>(&iedge_X->src());
    auto* access_W = dynamic_cast<data_flow::AccessNode*>(&iedge_W->src());
    auto* access_B = (has_bias ? dynamic_cast<data_flow::AccessNode*>(&iedge_B->src()) : nullptr);
    auto* access_Y = dynamic_cast<data_flow::AccessNode*>(&oedge_Y->dst());
    if (!access_X || !access_W || (has_bias && !access_B) || !access_Y) {
        return false;
    }

    // Get block & its parent; the expansion is inserted as a sibling sequence
    // right before the block that currently hosts this node.
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
    if (!block) {
        return false;
    }
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
    auto* block_parent = dynamic_cast<structured_control_flow::Sequence*>(scope_analysis.parent_scope(block));
    if (!block_parent) {
        return false;
    }
    size_t block_index = block_parent->index(*block);
    if (block_index >= block_parent->size()) {
        return false;
    }

    // Determine BLAS precision from the node's element type; bail out for
    // anything other than half/float/double.
    blas::BLAS_Precision precision;
    types::Scalar base_type(this->primitive_type(dfg));
    switch (base_type.primitive_type()) {
        case types::PrimitiveType::Half:
            precision = blas::BLAS_Precision::h;
            break;
        case types::PrimitiveType::Float:
            precision = blas::BLAS_Precision::s;
            break;
        case types::PrimitiveType::Double:
            precision = blas::BLAS_Precision::d;
            break;
        default:
            return false;
    }

    // Create new sequence for expansion, inheriting the transition
    // assignments attached to the original block.
    auto& new_sequence = builder.add_sequence_before(
        *block_parent, *block, block_parent->at(block_index).second.assignments(), block->debug_info()
    );

    // Dimensions, i.e., 1D, 2D, 3D, ...
    size_t dims = this->kernel_shape_.size();
    symbolic::MultiExpression out_shape;
    out_shape.reserve(dims);
    // Standard convolution output-size formula:
    // out_shape[i] = (shape[i + 2] + pads[i] + pads[dims + i] - dilations[i] * (kernel_shape[i] - 1) - 1)
    //                 / strides[i] + 1
    for (size_t i = 0; i < dims; i++) {
        out_shape.push_back(symbolic::add(
            symbolic::div(
                symbolic::sub(
                    symbolic::
                        sub(symbolic::add(this->shape_[i + 2], symbolic::add(this->pads_[i], this->pads_[dims + i])),
                            symbolic::mul(this->dilations_[i], symbolic::sub(this->kernel_shape_[i], symbolic::one()))),
                    symbolic::one()
                ),
                this->strides_[i]
            ),
            symbolic::one()
        ));
    }
    types::Scalar indvar_type(types::PrimitiveType::Int64);

    // Per-group channel counts.
    auto in_channels = symbolic::div(this->shape_[1], this->group_);
    auto out_channels = symbolic::div(this->output_channels_, this->group_);

    // Add loop over batch size
    auto n_container = builder.find_new_name("_n");
    builder.add_container(n_container, indvar_type);
    auto n = symbolic::symbol(n_container);
    auto& loop_n = builder.add_map(
        new_sequence,
        n,
        symbolic::Lt(n, this->shape_[0]),
        symbolic::zero(),
        symbolic::add(n, symbolic::one()),
        ScheduleType_Sequential::create(),
        {},
        block->debug_info()
    );

    // Add loop over groups
    auto g_container = builder.find_new_name("_g");
    builder.add_container(g_container, indvar_type);
    auto g = symbolic::symbol(g_container);
    auto& loop_g = builder.add_map(
        loop_n.root(),
        g,
        symbolic::Lt(g, this->group_),
        symbolic::zero(),
        symbolic::add(g, symbolic::one()),
        ScheduleType_Sequential::create(),
        {},
        block->debug_info()
    );

    // Add patches container with malloc.
    // patches holds one group's im2col matrix: in_channels x prod(kernel) x
    // prod(out_shape) elements.
    // NOTE(review): the malloc/free pair lives inside the (n, g) loop nest,
    // so the buffer is reallocated every iteration — hoisting it out would
    // save allocations; confirm whether a later pass does that.
    symbolic::Expression patches_size = in_channels;
    for (size_t i = 0; i < dims; i++) {
        patches_size = symbolic::mul(patches_size, symbolic::mul(this->kernel_shape_[i], out_shape[i]));
    }
    types::Pointer patches_type(base_type);
    auto patches_container = builder.find_new_name("_patches");
    builder.add_container(patches_container, patches_type);
    auto& patches_malloc_block = builder.add_block(loop_g.root(), {}, block->debug_info());
    {
        auto& patches_access = builder.add_access(patches_malloc_block, patches_container, this->debug_info());
        auto& libnode = builder.add_library_node<stdlib::MallocNode>(
            patches_malloc_block, this->debug_info(), symbolic::mul(patches_size, symbolic::size_of_type(base_type))
        );
        builder.add_computational_memlet(
            patches_malloc_block, libnode, "_ret", patches_access, {}, patches_type, this->debug_info()
        );
    }

    // Add loop over channels
    structured_control_flow::Sequence* current_seq = &loop_g.root();
    auto c_container = builder.find_new_name("_c");
    builder.add_container(c_container, indvar_type);
    auto c = symbolic::symbol(c_container);
    auto& loop_c = builder.add_map(
        *current_seq,
        c,
        symbolic::Lt(c, in_channels),
        symbolic::zero(),
        symbolic::add(c, symbolic::one()),
        ScheduleType_Sequential::create(),
        {},
        block->debug_info()
    );
    current_seq = &loop_c.root();

    // Add loops over kernel shape
    symbolic::SymbolVec ks;
    ks.reserve(dims);
    for (size_t i = 0; i < dims; i++) {
        auto k_container = builder.find_new_name("_k");
        builder.add_container(k_container, indvar_type);
        auto k = symbolic::symbol(k_container);
        ks.push_back(k);
        auto& loop_k = builder.add_map(
            *current_seq,
            k,
            symbolic::Lt(k, this->kernel_shape_[i]),
            symbolic::zero(),
            symbolic::add(k, symbolic::one()),
            ScheduleType_Sequential::create(),
            {},
            block->debug_info()
        );
        current_seq = &loop_k.root();
    }

    // Add loops over output dimensions
    symbolic::SymbolVec os;
    os.reserve(dims);
    for (size_t i = 0; i < dims; i++) {
        auto o_container = builder.find_new_name("_o");
        builder.add_container(o_container, indvar_type);
        auto o = symbolic::symbol(o_container);
        os.push_back(o);
        auto& loop_o = builder.add_map(
            *current_seq,
            o,
            symbolic::Lt(o, out_shape[i]),
            symbolic::zero(),
            symbolic::add(o, symbolic::one()),
            ScheduleType_Sequential::create(),
            {},
            block->debug_info()
        );
        current_seq = &loop_o.root();
    }

    // Add if/else to stay in bounds for copying: in-bounds taps copy from X,
    // out-of-bounds taps (padding region) write zero into patches.
    // is[i] = o * stride - pad_begin + k * dilation (input coordinate).
    symbolic::MultiExpression is;
    is.reserve(dims);
    symbolic::Condition copy_condition = symbolic::__true__();
    symbolic::Condition zero_condition = symbolic::__false__();
    for (size_t i = 0; i < dims; i++) {
        auto i_expr = symbolic::
            add(symbolic::sub(symbolic::mul(os[i], this->strides_[i]), this->pads_[i]),
                symbolic::mul(ks[i], this->dilations_[i]));
        is.push_back(i_expr);
        copy_condition = symbolic::
            And(copy_condition,
                symbolic::And(symbolic::Lt(i_expr, this->shape_[i + 2]), symbolic::Ge(i_expr, symbolic::zero())));
        zero_condition = symbolic::
            Or(zero_condition,
               symbolic::Or(symbolic::Ge(i_expr, this->shape_[i + 2]), symbolic::Lt(i_expr, symbolic::zero())));
    }
    auto& branch = builder.add_if_else(*current_seq, {}, block->debug_info());
    auto& copy_case = builder.add_case(branch, copy_condition, block->debug_info());
    auto& zero_case = builder.add_case(branch, zero_condition, block->debug_info());

    // Determine patches subset & tensor type
    data_flow::Subset patches_subset;
    patches_subset.push_back(c);
    patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
    patches_subset.insert(patches_subset.end(), os.begin(), os.end());
    symbolic::MultiExpression patches_shape;
    patches_shape.push_back(in_channels);
    patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
    patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
    types::Tensor patches_tensor_type(base_type, patches_shape);

    // Determine subset for X: (n, g * in_channels + c, spatial coords).
    data_flow::Subset subset_X;
    subset_X.push_back(n);
    subset_X.push_back(symbolic::add(symbolic::mul(in_channels, g), c));
    subset_X.insert(subset_X.end(), is.begin(), is.end());

    // Add copy from X to patches
    auto& copy_block = builder.add_block(copy_case, {}, block->debug_info());
    {
        auto& X_access = builder.add_access(copy_block, access_X->data(), access_X->debug_info());
        auto& patches_access = builder.add_access(copy_block, patches_container, this->debug_info());
        auto& tasklet =
            builder.add_tasklet(copy_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
        builder.add_computational_memlet(
            copy_block, X_access, tasklet, "_in", subset_X, iedge_X->base_type(), iedge_X->debug_info()
        );
        builder.add_computational_memlet(
            copy_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
        );
    }

    // Add zero assignment to patches
    auto& zero_block = builder.add_block(zero_case, {}, block->debug_info());
    {
        auto& constant_zero = builder.add_constant(zero_block, "0.0", base_type, this->debug_info());
        auto& patches_access = builder.add_access(zero_block, patches_container, this->debug_info());
        auto& tasklet =
            builder.add_tasklet(zero_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
        builder.add_computational_memlet(zero_block, constant_zero, tasklet, "_in", {}, base_type, this->debug_info());
        builder.add_computational_memlet(
            zero_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
        );
    }

    // Add reference to W: flat offset of group g's weight slab,
    // g * out_channels * in_channels * prod(kernel_shape).
    auto ref_W_container = builder.find_new_name("_ref_W");
    types::Scalar ref_W_base_type(builder.subject().type(access_W->data()).primitive_type());
    types::Pointer ref_W_type(ref_W_base_type);
    builder.add_container(ref_W_container, ref_W_type);
    auto ref_W_subset = symbolic::mul(symbolic::mul(out_channels, g), in_channels);
    for (size_t i = 0; i < dims; i++) {
        ref_W_subset = symbolic::mul(ref_W_subset, this->kernel_shape_[i]);
    }
    auto& ref_W_block = builder.add_block(loop_g.root(), {}, block->debug_info());
    {
        auto& W_access = builder.add_access(ref_W_block, access_W->data(), access_W->debug_info());
        auto& ref_W_access = builder.add_access(ref_W_block, ref_W_container, access_W->debug_info());
        builder.add_reference_memlet(ref_W_block, W_access, ref_W_access, {ref_W_subset}, ref_W_type);
    }

    // Add reference to Y: flat offset of (n, g)'s output slab,
    // (n * output_channels + g * out_channels) * prod(out_shape).
    auto ref_Y_container = builder.find_new_name("_ref_Y");
    types::Scalar ref_Y_base_type(builder.subject().type(access_Y->data()).primitive_type());
    types::Pointer ref_Y_type(ref_Y_base_type);
    builder.add_container(ref_Y_container, ref_Y_type);
    auto ref_Y_subset = symbolic::add(symbolic::mul(this->output_channels_, n), symbolic::mul(out_channels, g));
    for (size_t i = 0; i < dims; i++) {
        ref_Y_subset = symbolic::mul(ref_Y_subset, out_shape[i]);
    }
    auto& ref_Y_block = builder.add_block(loop_g.root(), {}, block->debug_info());
    {
        auto& Y_access = builder.add_access(ref_Y_block, access_Y->data(), access_Y->debug_info());
        auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, access_Y->debug_info());
        builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_subset}, ref_Y_type);
    }

    // Add GEMM node: Y_slab (m x n) = 1.0 * W_slab (m x k) * patches (k x n)
    // with m = out_channels, n = prod(out_shape), k = in_channels * prod(kernel).
    auto& gemm_block = builder.add_block(loop_g.root(), {}, block->debug_info());
    {
        auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
        auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
        auto& ref_W_access = builder.add_access(gemm_block, ref_W_container, access_W->debug_info());
        auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
        auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
        auto& ref_Y_access_out = builder.add_access(gemm_block, ref_Y_container, access_Y->debug_info());
        symbolic::Expression gemm_m = out_channels;
        symbolic::Expression gemm_n = symbolic::one();
        symbolic::Expression gemm_k = in_channels;
        for (size_t i = 0; i < dims; i++) {
            gemm_n = symbolic::mul(gemm_n, out_shape[i]);
            gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
        }
        auto& libnode = builder.add_library_node<blas::GEMMNode>(
            gemm_block,
            this->debug_info(),
            blas::ImplementationType_BLAS,
            precision, // precision
            blas::BLAS_Layout::RowMajor, // layout
            blas::BLAS_Transpose::No, // transA
            blas::BLAS_Transpose::No, // transB
            gemm_m, // m
            gemm_n, // n
            gemm_k, // k
            gemm_k, // lda
            gemm_n, // ldb
            gemm_n // ldc
        );
        builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
        builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
        builder
            .add_computational_memlet(gemm_block, ref_W_access, libnode, "__A", {}, ref_W_type, iedge_W->debug_info());
        builder
            .add_computational_memlet(gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info());
        builder
            .add_computational_memlet(gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, oedge_Y->debug_info());
        builder
            .add_computational_memlet(gemm_block, libnode, "__C", ref_Y_access_out, {}, ref_Y_type, oedge_Y->debug_info());
    }

    // Add bias if available: Y[n, g*out_channels+l, o...] += B[g*out_channels+l]
    if (has_bias) {
        // Add loop over output channels
        auto l_container = builder.find_new_name("_l");
        builder.add_container(l_container, indvar_type);
        auto l = symbolic::symbol(l_container);
        auto& loop_l = builder.add_map(
            loop_g.root(),
            l,
            symbolic::Lt(l, out_channels),
            symbolic::zero(),
            symbolic::add(l, symbolic::one()),
            ScheduleType_Sequential::create(),
            {},
            block->debug_info()
        );
        current_seq = &loop_l.root();

        // Add loops over output dimensions (again); `os` is rebound to the
        // fresh loop symbols so Y_subset below uses the new indices.
        for (size_t i = 0; i < dims; i++) {
            auto o_container = builder.find_new_name("_o");
            builder.add_container(o_container, indvar_type);
            auto o = symbolic::symbol(o_container);
            auto& loop_o = builder.add_map(
                *current_seq,
                o,
                symbolic::Lt(o, out_shape[i]),
                symbolic::zero(),
                symbolic::add(o, symbolic::one()),
                ScheduleType_Sequential::create(),
                {},
                block->debug_info()
            );
            current_seq = &loop_o.root();
            os[i] = o;
        }

        // Add bias to Y
        data_flow::Subset Y_subset;
        Y_subset.push_back(n);
        Y_subset.push_back(symbolic::add(symbolic::mul(out_channels, g), l));
        Y_subset.insert(Y_subset.end(), os.begin(), os.end());
        auto B_subset = symbolic::add(symbolic::mul(out_channels, g), l);
        auto& bias_block = builder.add_block(*current_seq, {}, block->debug_info());
        {
            auto& B_access = builder.add_access(bias_block, access_B->data(), access_B->debug_info());
            auto& Y_access_in = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
            auto& Y_access_out = builder.add_access(bias_block, access_Y->data(), access_Y->debug_info());
            auto& tasklet =
                builder
                    .add_tasklet(bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info());
            builder.add_computational_memlet(
                bias_block, Y_access_in, tasklet, "_in1", Y_subset, oedge_Y->base_type(), this->debug_info()
            );
            builder.add_computational_memlet(
                bias_block, B_access, tasklet, "_in2", {B_subset}, iedge_B->base_type(), iedge_B->debug_info()
            );
            builder.add_computational_memlet(
                bias_block, tasklet, "_out", Y_access_out, Y_subset, oedge_Y->base_type(), oedge_Y->debug_info()
            );
        }
    }

    // Add free for patches container
    auto& patches_free_block = builder.add_block(loop_g.root(), {}, block->debug_info());
    {
        auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
        auto& patches_access_out = builder.add_access(patches_free_block, patches_container, this->debug_info());
        auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
        builder.add_computational_memlet(
            patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
        );
        builder.add_computational_memlet(
            patches_free_block, libnode, "_ptr", patches_access_out, {}, patches_type, this->debug_info()
        );
    }

    // Clean up the original block: detach all memlets and nodes, then remove
    // the now-empty block itself (at block_index + 1 because the expansion
    // sequence was inserted before it).
    builder.remove_memlet(*block, *iedge_X);
    builder.remove_memlet(*block, *iedge_W);
    if (has_bias) {
        builder.remove_memlet(*block, *iedge_B);
    }
    builder.remove_memlet(*block, *oedge_Y);
    builder.remove_node(*block, *access_X);
    builder.remove_node(*block, *access_W);
    if (has_bias) {
        builder.remove_node(*block, *access_B);
    }
    builder.remove_node(*block, *access_Y);
    builder.remove_node(*block, *this);
    builder.remove_child(*block_parent, block_index + 1);

    return true;
}
574

575
// Returns the set of free symbols referenced by any of the node's symbolic
// attributes (shape, kernel_shape, strides, pads, dilations, output_channels,
// group).
//
// Defect fixed: the original repeated the same atom-collection loop seven
// times; the logic is now shared through a single local helper.
symbolic::SymbolSet ConvNode::symbols() const {
    symbolic::SymbolSet syms;

    // Inserts every symbolic atom of `expr` into `syms`.
    auto collect = [&syms](const symbolic::Expression& expr) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    };

    for (auto& expr : shape_) {
        collect(expr);
    }
    for (auto& expr : kernel_shape_) {
        collect(expr);
    }
    for (auto& expr : strides_) {
        collect(expr);
    }
    for (auto& expr : pads_) {
        collect(expr);
    }
    for (auto& expr : dilations_) {
        collect(expr);
    }
    collect(output_channels_);
    collect(group_);

    return syms;
}
612

613
// Substitutes `old_expression` with `new_expression` in every attribute
// expression of this node, in place.
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // In-place substitution applied uniformly to every stored expression.
    auto substitute = [&](symbolic::Expression& expr) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    };

    for (auto& expr : shape_) {
        substitute(expr);
    }
    for (auto& expr : kernel_shape_) {
        substitute(expr);
    }
    for (auto& expr : strides_) {
        substitute(expr);
    }
    for (auto& expr : pads_) {
        substitute(expr);
    }
    for (auto& expr : dilations_) {
        substitute(expr);
    }
    substitute(output_channels_);
    substitute(group_);
}
632

633
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
    // Create a fresh node carrying over all convolution attributes, then hand
    // ownership to the caller through the base-class pointer type.
    auto* copy = new ConvNode(
        element_id, this->debug_info(), vertex, parent,
        shape_, kernel_shape_, strides_, pads_, dilations_,
        output_channels_, group_
    );
    return std::unique_ptr<data_flow::DataFlowNode>(copy);
}
649

650
std::string ConvNode::toStr() const {
×
651
    std::stringstream result;
×
652
    result << "Conv(shape=[";
×
653
    for (size_t i = 0; i < shape_.size(); ++i) {
×
654
        if (i > 0) {
×
655
            result << ", ";
×
656
        }
×
657
        result << shape_[i]->__str__();
×
658
    }
×
659
    result << "], kernel_shape=[";
×
660
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
661
        if (i > 0) {
×
662
            result << ", ";
×
663
        }
×
664
        result << kernel_shape_[i]->__str__();
×
665
    }
×
666
    result << "], strides=[";
×
667
    for (size_t i = 0; i < strides_.size(); ++i) {
×
668
        if (i > 0) {
×
669
            result << ", ";
×
670
        }
×
671
        result << strides_[i]->__str__();
×
672
    }
×
673
    result << "], pads=[";
×
674
    for (size_t i = 0; i < pads_.size(); ++i) {
×
675
        if (i > 0) {
×
676
            result << ", ";
×
677
        }
×
678
        result << pads_[i]->__str__();
×
679
    }
×
680
    result << "], dilations=[";
×
681
    for (size_t i = 0; i < dilations_.size(); ++i) {
×
682
        if (i > 0) {
×
683
            result << ", ";
×
684
        }
×
685
        result << dilations_[i]->__str__();
×
686
    }
×
687
    result << "], output_channels=" + output_channels_->__str__();
×
688
    result << ", group=" + group_->__str__() + ")";
×
689
    return result.str();
×
690
}
×
691

692
// Serializes a ConvNode into JSON: the node code, each expression list
// ("shape", "kernel_shape", "strides", "pads", "dilations") as an array of
// expression strings, and the scalar "output_channels" / "group" expressions.
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const ConvNode& conv_node = static_cast<const ConvNode&>(library_node);
    nlohmann::json j;

    j["code"] = conv_node.code().value();

    serializer::JSONSerializer serializer;

    // Writes `exprs` as a JSON array of serialized expressions under `key`.
    auto serialize_array = [&](const char* key, const std::vector<symbolic::Expression>& exprs) {
        j[key] = nlohmann::json::array();
        for (auto& expr : exprs) {
            j[key].push_back(serializer.expression(expr));
        }
    };

    serialize_array("shape", conv_node.shape());
    serialize_array("kernel_shape", conv_node.kernel_shape());
    serialize_array("strides", conv_node.strides());
    serialize_array("pads", conv_node.pads());
    serialize_array("dilations", conv_node.dilations());

    j["output_channels"] = serializer.expression(conv_node.output_channels());
    j["group"] = serializer.expression(conv_node.group());

    return j;
}
730

731
// Reconstructs a ConvNode from its JSON representation (inverse of
// `serialize`) and inserts it into `parent` via `builder`. Expression-list
// keys other than the mandatory "kernel_shape" are optional and default to
// empty; "output_channels" and "group" default to 1 when absent.
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parses an optional JSON array of expression strings under `key`;
    // yields an empty vector when the key is absent.
    auto parse_expr_array = [&j](const char* key) -> std::vector<symbolic::Expression> {
        std::vector<symbolic::Expression> exprs;
        if (j.contains(key)) {
            for (const auto& entry : j[key]) {
                exprs.push_back(symbolic::parse(entry.get<std::string>()));
            }
        }
        return exprs;
    };

    // Parses an optional scalar expression under `key`; defaults to 1.
    auto parse_expr_or_default = [&j](const char* key) -> symbolic::Expression {
        if (j.contains(key)) {
            return symbolic::parse(j[key].get<std::string>());
        }
        return symbolic::one();
    };

    std::vector<symbolic::Expression> shape = parse_expr_array("shape");
    // Presence of "kernel_shape" is guaranteed by the assert above.
    std::vector<symbolic::Expression> kernel_shape = parse_expr_array("kernel_shape");
    std::vector<symbolic::Expression> strides = parse_expr_array("strides");
    std::vector<symbolic::Expression> pads = parse_expr_array("pads");
    std::vector<symbolic::Expression> dilations = parse_expr_array("dilations");

    symbolic::Expression output_channels = parse_expr_or_default("output_channels");
    symbolic::Expression group = parse_expr_or_default("group");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
788

789
} // namespace tensor
790
} // namespace math
791
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc