• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 28158800507

25 Jun 2026 08:57AM UTC coverage: 61.644% (+0.06%) from 61.582%
28158800507

push

github

web-flow
MapFusionByDomain (#771)

 + New Map fusion caches data about iteration domain and map candidates
 + only matches up iteration domain exactly, per loop level.
 + Can support fusing non-leaf stacks of loops (stack ends where the shallower stack stops being perfectly nested & parallel)
 + new Element::replace for bulk replacements
 + New PatternMatcher visitor supports descending into replaced or modified nodes to allow for single-pass nested loop fusings
 + LoopAnalysis can now be kept up-to-date with changes done by Map-fusion
 + unit tests for the updating of LoopAnalysis
 * updated LoopAnalysis to be easier to keep up-to-date with changes. LoopTree is no longer ordered, if you want to iterate in pre-order, use the specific method for that
 + convenience StructuredSDFGBuilder.remove_from_parent()
 + RedundantLoadElim pass to skip reading from memory locations that have just been written (same block). Fusing no longer needs to do this
     RedundantLoadElimination does a simple check for other writes to the same structure. It can skip over such writes if they are redundant or do not modify the loaded location, e.g. if they write to different indices.
* Updated verifiers to match new fusion
~ moved verifier checks behind correctness checks in npbench harness. It's more critical if we do not even get the expected results
* Added MapFusionByDomain also to loop-norm stage (currently inactive, as it causes more kernels that currently cannot be safely offloaded to CUDA).
---------

Co-authored-by: Lukas Truemper <lukas.truemper@outlook.de>

771 of 1186 new or added lines in 55 files covered. (65.01%)

6 existing lines in 6 files now uncovered.

38302 of 62134 relevant lines covered (61.64%)

987.24 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.38
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4
#include <sstream>
5
#include <utility>
6

7
#include "sdfg/analysis/analysis.h"
8
#include "sdfg/builder/structured_sdfg_builder.h"
9
#include "sdfg/data_flow/access_node.h"
10
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
11
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
12
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
13
#include "sdfg/data_flow/library_nodes/stdlib/memset.h"
14
#include "sdfg/data_flow/memlet.h"
15
#include "sdfg/data_flow/tasklet.h"
16
#include "sdfg/exceptions.h"
17
#include "sdfg/structured_control_flow/block.h"
18
#include "sdfg/structured_control_flow/map.h"
19
#include "sdfg/structured_control_flow/sequence.h"
20
#include "sdfg/symbolic/symbolic.h"
21
#include "sdfg/types/pointer.h"
22
#include "sdfg/types/scalar.h"
23
#include "sdfg/types/tensor.h"
24
#include "sdfg/types/type.h"
25

26
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
27
#include "symengine/integer.h"
28
#include "symengine/symengine_rcp.h"
29

30
namespace sdfg {
31
namespace math {
32
namespace tensor {
33

34
ConvNode::ConvNode(
35
    size_t element_id,
36
    const DebugInfo& debug_info,
37
    const graph::Vertex vertex,
38
    data_flow::DataFlowGraph& parent,
39
    const std::vector<symbolic::Expression>& shape,
40
    const std::vector<symbolic::Expression>& kernel_shape,
41
    const std::vector<symbolic::Expression>& strides,
42
    const std::vector<symbolic::Expression>& pads,
43
    const std::vector<symbolic::Expression>& dilations,
44
    symbolic::Expression output_channels,
45
    symbolic::Expression group,
46
    bool with_bias,
47
    QuantizationType quantization,
48
    const data_flow::ImplementationType& impl_type
49
)
50
    : SpatialTensorNode(
47✔
51
          element_id,
47✔
52
          debug_info,
47✔
53
          vertex,
47✔
54
          parent,
47✔
55
          LibraryNodeType_Conv,
47✔
56
          {},
47✔
57
          {"Y", "X", "W"}, // X and W are required, B (bias) is optional
47✔
58
          impl_type,
47✔
59
          quantization,
47✔
60
          shape,
47✔
61
          kernel_shape,
47✔
62
          strides,
47✔
63
          pads,
47✔
64
          dilations
47✔
65
      ),
47✔
66
      output_channels_(std::move(output_channels)), group_(std::move(group)), with_bias_(with_bias) {
47✔
67
    if (with_bias) {
47✔
68
        inputs_.push_back("B");
5✔
69
    }
5✔
70
}
47✔
71

72
void ConvNode::validate(const Function& function) const {
82✔
73
    TensorNode::validate(function);
82✔
74

75
    auto& graph = this->get_parent();
82✔
76

77
    // Custom validation for ConvNode that handles optional bias input
78
    // We expect X, W as required inputs and optionally B (bias)
79

80
    // Collect all input edges by connector name
81
    std::map<std::string, const data_flow::Memlet*> input_edges;
82✔
82
    for (auto& iedge : graph.in_edges(*this)) {
250✔
83
        input_edges[iedge.dst_conn()] = &iedge;
250✔
84
    }
250✔
85

86
    // Check that required inputs X and W are present
87
    if (input_edges.find("X") == input_edges.end()) {
82✔
88
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
89
    }
×
90
    if (input_edges.find("W") == input_edges.end()) {
82✔
91
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
92
    }
×
93

94
    // Validate that parameters are not empty
95
    if (shape_.empty()) {
82✔
96
        throw InvalidSDFGException("ConvNode shape cannot be empty");
×
97
    }
×
98
    if (kernel_shape_.empty()) {
82✔
99
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
100
    }
×
101
    if (strides_.empty()) {
82✔
102
        throw InvalidSDFGException("ConvNode strides cannot be empty");
×
103
    }
×
104
    if (pads_.empty()) {
82✔
105
        throw InvalidSDFGException("ConvNode pads cannot be empty");
×
106
    }
×
107
    if (dilations_.empty()) {
82✔
108
        throw InvalidSDFGException("ConvNode dilations cannot be empty");
×
109
    }
×
110

111
    // Validate consistent dimensions
112
    size_t spatial_dims = kernel_shape_.size();
82✔
113

114
    if (shape_.size() != spatial_dims + 2) {
82✔
115
        throw InvalidSDFGException("ConvNode shape must match kernel spatial dimensions + 2");
×
116
    }
×
117

118
    if (strides_.size() != spatial_dims) {
82✔
119
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
120
    }
1✔
121

122
    if (pads_.size() != 2 * spatial_dims) {
81✔
123
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
124
    }
1✔
125

126
    if (dilations_.size() != spatial_dims) {
80✔
127
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
128
    }
×
129

130
    // Validate groups
131
    if (SymEngine::is_a<SymEngine::Integer>(*this->group_)) {
80✔
132
        auto group_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->group_)->as_int();
80✔
133
        if (SymEngine::is_a<SymEngine::Integer>(*this->shape_[1])) {
80✔
134
            auto input_channels_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->shape_[1])->as_int();
80✔
135
            if (input_channels_int % group_int != 0) {
80✔
136
                throw InvalidSDFGException("ConvNode input channels must be divisible by groups");
×
137
            }
×
138
        }
80✔
139
        if (SymEngine::is_a<SymEngine::Integer>(*this->output_channels_)) {
80✔
140
            auto output_channels_int =
80✔
141
                SymEngine::rcp_static_cast<const SymEngine::Integer>(this->output_channels_)->as_int();
80✔
142
            if (output_channels_int % group_int != 0) {
80✔
143
                throw InvalidSDFGException("ConvNode output channels must be divisible by groups");
×
144
            }
×
145
        }
80✔
146
    }
80✔
147
}
80✔
148

149
/// Maps a scalar element type to the BLAS precision used for the GEMM in expand():
/// Half -> h, Float -> s, Double -> d. Every other primitive type yields
/// BLAS_Precision::invalid, which makes expand() bail out.
blas::BLAS_Precision ConvNode::get_blas_precision(types::Scalar base_type) {
    const auto prim = base_type.primitive_type();
    if (prim == types::PrimitiveType::Half) {
        return blas::BLAS_Precision::h;
    }
    if (prim == types::PrimitiveType::Float) {
        return blas::BLAS_Precision::s;
    }
    if (prim == types::PrimitiveType::Double) {
        return blas::BLAS_Precision::d;
    }
    return blas::BLAS_Precision::invalid;
}
161

162
/// Computes the output extent of every spatial axis i (one entry per kernel dimension):
///   out[i] = (shape[i + 2] + pads[i] + pads[dims + i]
///             - dilations[i] * (kernel_shape[i] - 1) - 1) / strides[i] + 1
/// pads[i] / pads[dims + i] are the start / end padding of axis i. The division is
/// symbolic::div, with whatever rounding semantics the symbolic library defines.
symbolic::MultiExpression ConvNode::get_out_shape() {
    const size_t rank = kernel_shape_.size();
    symbolic::MultiExpression extents;
    extents.reserve(rank);
    for (size_t axis = 0; axis < rank; ++axis) {
        // Input extent including padding on both sides.
        auto padded = symbolic::add(this->shape_[axis + 2], symbolic::add(this->pads_[axis], this->pads_[rank + axis]));
        // Extra reach of the dilated kernel beyond its first tap.
        auto reach = symbolic::mul(this->dilations_[axis], symbolic::sub(this->kernel_shape_[axis], symbolic::one()));
        auto numerator = symbolic::sub(symbolic::sub(padded, reach), symbolic::one());
        extents.push_back(symbolic::add(symbolic::div(numerator, this->strides_[axis]), symbolic::one()));
    }
    return extents;
}
184

185
bool ConvNode::has_bias() const { return with_bias_; }
×
186

187
bool ConvNode::check_expandable(
188
    data_flow::DataFlowGraph& dfg, analysis::AnalysisManager& analysis_manager, ConvExpandPrerequisits& boundary
189
) const {
27✔
190
    if ((dfg.nodes().size() != 4 || dfg.edges().size() != 3) && (dfg.nodes().size() != 5 || dfg.edges().size() != 4)) {
27✔
191
        return false;
4✔
192
    }
4✔
193

194
    // Get edges
195
    boundary.iedge_X = dfg.in_edge_for_connector(*this, "X");
23✔
196
    boundary.iedge_W = dfg.in_edge_for_connector(*this, "W");
23✔
197
    boundary.iedge_B = with_bias_ ? dfg.in_edge_for_connector(*this, "B") : nullptr;
23✔
198
    boundary.iedge_Y = dfg.in_edge_for_connector(*this, "Y");
23✔
199
    if (!boundary.iedge_X || !boundary.iedge_W || !boundary.iedge_Y) {
23✔
200
        return false;
×
201
    }
×
202
    boundary.has_bias = boundary.iedge_B != nullptr;
23✔
203

204
    // Get access nodes
205
    boundary.access_X = dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_X->src());
23✔
206
    boundary.access_W = dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_W->src());
23✔
207
    boundary.access_B =
23✔
208
        (boundary.has_bias ? dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_B->src()) : nullptr);
23✔
209
    boundary.access_Y = dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_Y->src());
23✔
210
    if (!boundary.access_X || !boundary.access_W || (boundary.has_bias && !boundary.access_B) || !boundary.access_Y) {
23✔
211
        return false;
×
212
    }
×
213

214
    // Get block & its parent
215
    boundary.block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
23✔
216
    if (!boundary.block) {
23✔
217
        return false;
×
218
    }
×
219

220
    boundary.block_parent = dynamic_cast<structured_control_flow::Sequence*>(boundary.block->get_parent());
23✔
221
    if (!boundary.block_parent) {
23✔
222
        return false;
×
223
    }
×
224

225
    boundary.block_index = boundary.block_parent->index(*boundary.block);
23✔
226
    if (boundary.block_index >= boundary.block_parent->size()) {
23✔
227
        return false;
×
228
    }
×
229

230
    return true;
23✔
231
}
23✔
232

233
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
234
    // Validate nodes are standalone in the data flow graph
235
    auto& dfg = this->get_parent();
5✔
236
    ConvExpandPrerequisits b;
5✔
237
    if (!check_expandable(dfg, analysis_manager, b)) {
5✔
238
        return false;
×
239
    }
×
240

241
    // Determine BLAS precision
242

243
    types::Scalar base_type(this->primitive_type(dfg));
5✔
244
    blas::BLAS_Precision precision = get_blas_precision(base_type);
5✔
245
    if (precision == blas::BLAS_Precision::invalid) {
5✔
246
        return false;
×
247
    }
×
248

249
    // Create new sequence for expansion
250
    auto& new_sequence = builder.add_sequence_before(
5✔
251
        *b.block_parent, *b.block, b.block_parent->at(b.block_index).second.assignments(), b.block->debug_info()
5✔
252
    );
5✔
253

254
    // Dimensions, i.e., 1D, 2D, 3D, ...
255
    size_t dims = this->kernel_shape_.size();
5✔
256
    symbolic::MultiExpression out_shape = get_out_shape();
5✔
257
    types::Scalar indvar_type(types::PrimitiveType::Int64);
5✔
258

259
    auto in_channels = symbolic::div(this->shape_[1], this->group_);
5✔
260
    auto out_channels = symbolic::div(this->output_channels_, this->group_);
5✔
261

262
    // Add loop over batch size
263
    auto n_container = builder.find_new_name("_n");
5✔
264
    builder.add_container(n_container, indvar_type);
5✔
265
    auto n = symbolic::symbol(n_container);
5✔
266
    auto& loop_n = builder.add_map(
5✔
267
        new_sequence,
5✔
268
        n,
5✔
269
        symbolic::Lt(n, this->shape_[0]),
5✔
270
        symbolic::zero(),
5✔
271
        symbolic::add(n, symbolic::one()),
5✔
272
        ScheduleType_Sequential::create(),
5✔
273
        {},
5✔
274
        b.block->debug_info()
5✔
275
    );
5✔
276

277
    // Add loop over groups
278
    auto g_container = builder.find_new_name("_g");
5✔
279
    builder.add_container(g_container, indvar_type);
5✔
280
    auto g = symbolic::symbol(g_container);
5✔
281
    auto& loop_g = builder.add_map(
5✔
282
        loop_n.root(),
5✔
283
        g,
5✔
284
        symbolic::Lt(g, this->group_),
5✔
285
        symbolic::zero(),
5✔
286
        symbolic::add(g, symbolic::one()),
5✔
287
        ScheduleType_Sequential::create(),
5✔
288
        {},
5✔
289
        b.block->debug_info()
5✔
290
    );
5✔
291

292
    // Add patches container with malloc
293
    symbolic::Expression patches_size = in_channels;
5✔
294
    for (size_t i = 0; i < dims; i++) {
15✔
295
        patches_size = symbolic::mul(patches_size, symbolic::mul(this->kernel_shape_[i], out_shape[i]));
10✔
296
    }
10✔
297
    types::Pointer patches_type(base_type);
5✔
298
    auto patches_container = builder.find_new_name("_patches");
5✔
299
    builder.add_container(patches_container, patches_type);
5✔
300
    auto [patches_malloc_block, patches_malloc_node] = stdlib::add_malloc_block(
5✔
301
        builder,
5✔
302
        loop_g.root(),
5✔
303
        patches_container,
5✔
304
        symbolic::mul(patches_size, symbolic::size_of_type(base_type)),
5✔
305
        patches_type,
5✔
306
        this->debug_info()
5✔
307
    );
5✔
308

309
    // Add loop over channels
310
    structured_control_flow::Sequence* current_seq = &loop_g.root();
5✔
311
    auto c_container = builder.find_new_name("_c");
5✔
312
    builder.add_container(c_container, indvar_type);
5✔
313
    auto c = symbolic::symbol(c_container);
5✔
314
    auto& loop_c = builder.add_map(
5✔
315
        *current_seq,
5✔
316
        c,
5✔
317
        symbolic::Lt(c, in_channels),
5✔
318
        symbolic::zero(),
5✔
319
        symbolic::add(c, symbolic::one()),
5✔
320
        ScheduleType_Sequential::create(),
5✔
321
        {},
5✔
322
        b.block->debug_info()
5✔
323
    );
5✔
324
    current_seq = &loop_c.root();
5✔
325

326
    // Add loops over kernel shape
327
    symbolic::SymbolVec ks;
5✔
328
    ks.reserve(dims);
5✔
329
    for (size_t i = 0; i < dims; i++) {
15✔
330
        auto k_container = builder.find_new_name("_k");
10✔
331
        builder.add_container(k_container, indvar_type);
10✔
332
        auto k = symbolic::symbol(k_container);
10✔
333
        ks.push_back(k);
10✔
334
        auto& loop_k = builder.add_map(
10✔
335
            *current_seq,
10✔
336
            k,
10✔
337
            symbolic::Lt(k, this->kernel_shape_[i]),
10✔
338
            symbolic::zero(),
10✔
339
            symbolic::add(k, symbolic::one()),
10✔
340
            ScheduleType_Sequential::create(),
10✔
341
            {},
10✔
342
            b.block->debug_info()
10✔
343
        );
10✔
344
        current_seq = &loop_k.root();
10✔
345
    }
10✔
346

347
    // Add loops over output dimensions
348
    symbolic::SymbolVec os;
5✔
349
    os.reserve(dims);
5✔
350
    for (size_t i = 0; i < dims; i++) {
15✔
351
        auto o_container = builder.find_new_name("_o");
10✔
352
        builder.add_container(o_container, indvar_type);
10✔
353
        auto o = symbolic::symbol(o_container);
10✔
354
        os.push_back(o);
10✔
355
        auto& loop_o = builder.add_map(
10✔
356
            *current_seq,
10✔
357
            o,
10✔
358
            symbolic::Lt(o, out_shape[i]),
10✔
359
            symbolic::zero(),
10✔
360
            symbolic::add(o, symbolic::one()),
10✔
361
            ScheduleType_Sequential::create(),
10✔
362
            {},
10✔
363
            b.block->debug_info()
10✔
364
        );
10✔
365
        current_seq = &loop_o.root();
10✔
366
    }
10✔
367

368
    // Add if/else to stay in bounds for copying
369
    symbolic::MultiExpression is;
5✔
370
    is.reserve(dims);
5✔
371
    symbolic::Condition copy_condition = symbolic::__true__();
5✔
372
    symbolic::Condition zero_condition = symbolic::__false__();
5✔
373
    for (size_t i = 0; i < dims; i++) {
15✔
374
        auto i_expr = symbolic::
10✔
375
            add(symbolic::sub(symbolic::mul(os[i], this->strides_[i]), this->pads_[i]),
10✔
376
                symbolic::mul(ks[i], this->dilations_[i]));
10✔
377
        is.push_back(i_expr);
10✔
378
        copy_condition = symbolic::
10✔
379
            And(copy_condition,
10✔
380
                symbolic::And(symbolic::Lt(i_expr, this->shape_[i + 2]), symbolic::Ge(i_expr, symbolic::zero())));
10✔
381
        zero_condition = symbolic::
10✔
382
            Or(zero_condition,
10✔
383
               symbolic::Or(symbolic::Ge(i_expr, this->shape_[i + 2]), symbolic::Lt(i_expr, symbolic::zero())));
10✔
384
    }
10✔
385
    auto& branch = builder.add_if_else(*current_seq, {}, b.block->debug_info());
5✔
386
    auto& copy_case = builder.add_case(branch, copy_condition, b.block->debug_info());
5✔
387
    auto& zero_case = builder.add_case(branch, zero_condition, b.block->debug_info());
5✔
388

389
    // Determine patches subset & tensor type
390
    data_flow::Subset patches_subset;
5✔
391
    patches_subset.push_back(c);
5✔
392
    patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
5✔
393
    patches_subset.insert(patches_subset.end(), os.begin(), os.end());
5✔
394
    symbolic::MultiExpression patches_shape;
5✔
395
    patches_shape.push_back(in_channels);
5✔
396
    patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
5✔
397
    patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
5✔
398
    types::Tensor patches_tensor_type(base_type, patches_shape);
5✔
399

400
    // Determine subset for X
401
    data_flow::Subset subset_X;
5✔
402
    subset_X.push_back(n);
5✔
403
    subset_X.push_back(symbolic::add(symbolic::mul(in_channels, g), c));
5✔
404
    subset_X.insert(subset_X.end(), is.begin(), is.end());
5✔
405

406
    // Add copy from X to patches
407
    auto& copy_block = builder.add_block(copy_case, {}, b.block->debug_info());
5✔
408
    {
5✔
409
        auto& X_access = builder.add_access(copy_block, b.access_X->data(), b.access_X->debug_info());
5✔
410
        auto& patches_access = builder.add_access(copy_block, patches_container, this->debug_info());
5✔
411
        auto& tasklet =
5✔
412
            builder.add_tasklet(copy_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
413
        builder.add_computational_memlet(
5✔
414
            copy_block, X_access, tasklet, "_in", subset_X, b.iedge_X->base_type(), b.iedge_X->debug_info()
5✔
415
        );
5✔
416
        builder.add_computational_memlet(
5✔
417
            copy_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
5✔
418
        );
5✔
419
    }
5✔
420

421
    // Add zero assignment to patches
422
    auto& zero_block = builder.add_block(zero_case, {}, b.block->debug_info());
5✔
423
    {
5✔
424
        auto& constant_zero = builder.add_constant(zero_block, "0.0", base_type, this->debug_info());
5✔
425
        auto& patches_access = builder.add_access(zero_block, patches_container, this->debug_info());
5✔
426
        auto& tasklet =
5✔
427
            builder.add_tasklet(zero_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
5✔
428
        builder.add_computational_memlet(zero_block, constant_zero, tasklet, "_in", {}, base_type, this->debug_info());
5✔
429
        builder.add_computational_memlet(
5✔
430
            zero_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
5✔
431
        );
5✔
432
    }
5✔
433

434
    // Add reference to W
435
    auto ref_W_container = builder.find_new_name("_ref_W");
5✔
436
    types::Scalar ref_W_base_type(builder.subject().type(b.access_W->data()).primitive_type());
5✔
437
    types::Pointer ref_W_type(ref_W_base_type);
5✔
438
    builder.add_container(ref_W_container, ref_W_type);
5✔
439
    auto ref_W_subset = symbolic::mul(symbolic::mul(out_channels, g), in_channels);
5✔
440
    for (size_t i = 0; i < dims; i++) {
15✔
441
        ref_W_subset = symbolic::mul(ref_W_subset, this->kernel_shape_[i]);
10✔
442
    }
10✔
443
    auto& ref_W_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
5✔
444
    {
5✔
445
        auto& W_access = builder.add_access(ref_W_block, b.access_W->data(), b.access_W->debug_info());
5✔
446
        auto& ref_W_access = builder.add_access(ref_W_block, ref_W_container, b.access_W->debug_info());
5✔
447
        builder.add_reference_memlet(ref_W_block, W_access, ref_W_access, {ref_W_subset}, ref_W_type);
5✔
448
    }
5✔
449

450
    // Add reference to Y
451
    auto ref_Y_container = builder.find_new_name("_ref_Y");
5✔
452
    types::Scalar ref_Y_base_type(builder.subject().type(b.access_Y->data()).primitive_type());
5✔
453
    types::Pointer ref_Y_type(ref_Y_base_type);
5✔
454
    builder.add_container(ref_Y_container, ref_Y_type);
5✔
455
    auto ref_Y_subset = symbolic::add(symbolic::mul(this->output_channels_, n), symbolic::mul(out_channels, g));
5✔
456
    for (size_t i = 0; i < dims; i++) {
15✔
457
        ref_Y_subset = symbolic::mul(ref_Y_subset, out_shape[i]);
10✔
458
    }
10✔
459
    auto& ref_Y_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
5✔
460
    {
5✔
461
        auto& Y_access = builder.add_access(ref_Y_block, b.access_Y->data(), b.access_Y->debug_info());
5✔
462
        auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, b.access_Y->debug_info());
5✔
463
        builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_subset}, ref_Y_type);
5✔
464
    }
5✔
465

466
    // Add GEMM node
467
    auto& gemm_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
5✔
468
    {
5✔
469
        auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
5✔
470
        auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
5✔
471
        auto& ref_W_access = builder.add_access(gemm_block, ref_W_container, b.access_W->debug_info());
5✔
472
        auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
5✔
473
        auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, b.access_Y->debug_info());
5✔
474
        symbolic::Expression gemm_m = out_channels;
5✔
475
        symbolic::Expression gemm_n = symbolic::one();
5✔
476
        symbolic::Expression gemm_k = in_channels;
5✔
477
        for (size_t i = 0; i < dims; i++) {
15✔
478
            gemm_n = symbolic::mul(gemm_n, out_shape[i]);
10✔
479
            gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
10✔
480
        }
10✔
481
        auto& libnode = builder.add_library_node<blas::GEMMNode>(
5✔
482
            gemm_block,
5✔
483
            this->debug_info(),
5✔
484
            blas::ImplementationType_BLAS,
5✔
485
            precision, // precision
5✔
486
            blas::BLAS_Layout::RowMajor, // layout
5✔
487
            blas::BLAS_Transpose::No, // transA
5✔
488
            blas::BLAS_Transpose::No, // transB
5✔
489
            gemm_m, // m
5✔
490
            gemm_n, // n
5✔
491
            gemm_k, // k
5✔
492
            gemm_k, // lda
5✔
493
            gemm_n, // ldb
5✔
494
            gemm_n // ldc
5✔
495
        );
5✔
496
        builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
5✔
497
        builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
5✔
498
        builder
5✔
499
            .add_computational_memlet(gemm_block, ref_W_access, libnode, "__A", {}, ref_W_type, b.iedge_W->debug_info());
5✔
500
        builder
5✔
501
            .add_computational_memlet(gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info());
5✔
502
        builder.add_computational_memlet(
5✔
503
            gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, b.iedge_Y->debug_info()
5✔
504
        );
5✔
505
    }
5✔
506

507
    // Add bias if available
508
    if (b.has_bias) {
5✔
509
        // Add loop over output channels
510
        auto l_container = builder.find_new_name("_l");
×
511
        builder.add_container(l_container, indvar_type);
×
512
        auto l = symbolic::symbol(l_container);
×
513
        auto& loop_l = builder.add_map(
×
514
            loop_g.root(),
×
515
            l,
×
516
            symbolic::Lt(l, out_channels),
×
517
            symbolic::zero(),
×
518
            symbolic::add(l, symbolic::one()),
×
519
            ScheduleType_Sequential::create(),
×
520
            {},
×
521
            b.block->debug_info()
×
522
        );
×
523
        current_seq = &loop_l.root();
×
524

525
        // Add loops over output dimensions (again)
526
        for (size_t i = 0; i < dims; i++) {
×
527
            auto o_container = builder.find_new_name("_o");
×
528
            builder.add_container(o_container, indvar_type);
×
529
            auto o = symbolic::symbol(o_container);
×
530
            auto& loop_o = builder.add_map(
×
531
                *current_seq,
×
532
                o,
×
533
                symbolic::Lt(o, out_shape[i]),
×
534
                symbolic::zero(),
×
535
                symbolic::add(o, symbolic::one()),
×
536
                ScheduleType_Sequential::create(),
×
537
                {},
×
538
                b.block->debug_info()
×
539
            );
×
540
            current_seq = &loop_o.root();
×
541
            os[i] = o;
×
542
        }
×
543

544
        // Add bias to Y
545
        data_flow::Subset Y_subset;
×
546
        Y_subset.push_back(n);
×
547
        Y_subset.push_back(symbolic::add(symbolic::mul(out_channels, g), l));
×
548
        Y_subset.insert(Y_subset.end(), os.begin(), os.end());
×
549
        auto B_subset = symbolic::add(symbolic::mul(out_channels, g), l);
×
550
        auto& bias_block = builder.add_block(*current_seq, {}, b.block->debug_info());
×
551
        {
×
552
            auto& B_access = builder.add_access(bias_block, b.access_B->data(), b.access_B->debug_info());
×
553
            auto& Y_access_in = builder.add_access(bias_block, b.access_Y->data(), b.access_Y->debug_info());
×
554
            auto& Y_access_out = builder.add_access(bias_block, b.access_Y->data(), b.access_Y->debug_info());
×
555
            auto& tasklet =
×
556
                builder
×
557
                    .add_tasklet(bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info());
×
558
            builder.add_computational_memlet(
×
559
                bias_block, Y_access_in, tasklet, "_in1", Y_subset, b.iedge_Y->base_type(), this->debug_info()
×
560
            );
×
561
            builder.add_computational_memlet(
×
562
                bias_block, B_access, tasklet, "_in2", {B_subset}, b.iedge_B->base_type(), b.iedge_B->debug_info()
×
563
            );
×
564
            builder.add_computational_memlet(
×
565
                bias_block, tasklet, "_out", Y_access_out, Y_subset, b.iedge_Y->base_type(), b.iedge_Y->debug_info()
×
566
            );
×
567
        }
×
568
    }
×
569

570
    // Add free for patches container
571
    auto& patches_free_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
5✔
572
    {
5✔
573
        auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
5✔
574
        auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
5✔
575
        builder.add_computational_memlet(
5✔
576
            patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
5✔
577
        );
5✔
578
    }
5✔
579

580
    // Clean up the original block
581
    builder.clear_code_node_legacy(*b.block, *this);
5✔
582
    // WARNING: this has been deallocated at this point!!
583
    builder.remove_child(*b.block_parent, b.block_index + 1);
5✔
584

585
    return true;
5✔
586
}
5✔
587

588
/// Returns all symbols this node depends on: those reported by SpatialTensorNode plus the
/// atoms of the output-channel and group expressions.
symbolic::SymbolSet ConvNode::symbols() const {
    auto result = SpatialTensorNode::symbols();
    for (const auto& expr : {output_channels_, group_}) {
        for (auto& atom : symbolic::atoms(expr)) {
            result.insert(atom);
        }
    }
    return result;
}
599

600
/// Substitutes `old_expression` with `new_expression` in every symbolic parameter: the ones
/// handled by SpatialTensorNode, plus the group count and the output channel count.
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    SpatialTensorNode::replace(old_expression, new_expression);
    this->group_ = symbolic::subs(this->group_, old_expression, new_expression);
    this->output_channels_ = symbolic::subs(this->output_channels_, old_expression, new_expression);
}
605

NEW
606
/// Applies all substitutions in `replacements` at once to every symbolic parameter: the
/// ones handled by SpatialTensorNode, plus the group count and the output channel count.
void ConvNode::replace(const symbolic::ExpressionMapping& replacements) {
    SpatialTensorNode::replace(replacements);
    this->group_ = symbolic::subs(this->group_, replacements);
    this->output_channels_ = symbolic::subs(this->output_channels_, replacements);
}
611

612
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
613
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
1✔
614
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
1✔
615
        element_id,
1✔
616
        this->debug_info(),
1✔
617
        vertex,
1✔
618
        parent,
1✔
619
        shape_,
1✔
620
        kernel_shape_,
1✔
621
        strides_,
1✔
622
        pads_,
1✔
623
        dilations_,
1✔
624
        output_channels_,
1✔
625
        group_,
1✔
626
        with_bias_,
1✔
627
        fixed_quantization_,
1✔
628
        implementation_type_
1✔
629
    ));
1✔
630
}
1✔
631

632
std::string ConvNode::toStr() const {
×
633
    std::stringstream result;
×
634
    result << "Conv(";
×
635
    SpatialTensorNode::operator<<(result);
×
636

637
    result << ", output_channels=" + output_channels_->__str__();
×
638
    result << ", group=" + group_->__str__() + ")";
×
639
    return result.str();
×
640
}
×
641

642
/// Floating-point operation count:
///   outputs * K (multiplications) + outputs * (K - 1) (additions),
/// where outputs = num_output_elements() and K = kernel_iteration_count().
symbolic::Expression ConvNode::flop() const {
    const auto outputs = num_output_elements();
    const auto taps = kernel_iteration_count();
    const auto multiplications = symbolic::mul(outputs, taps);
    const auto additions = symbolic::mul(outputs, symbolic::sub(taps, symbolic::one()));
    return symbolic::add(multiplications, additions);
}
652

653
/// Pointer access pattern per input connector.
/// Index 0 ("Y") is reported as full write-only. The remaining own connectors ("X", "W",
/// and "B" when present) are read-only. Any other index is delegated to TensorNode.
data_flow::PointerAccessType ConvNode::pointer_access_type(int input_idx) const {
    if (input_idx == 0) {
        return data_flow::PointerAccessMeta::create_full_write_only(symbolic::__nullptr__(), true);
    }
    // input_idx > 0 here, so the cast to size_t is lossless.
    if (input_idx > 0 && static_cast<size_t>(input_idx) < inputs_.size()) {
        return data_flow::PointerAccessMeta::create_read_only(symbolic::__nullptr__(), true);
    }
    return TensorNode::pointer_access_type(input_idx);
}
662

663
/// Number of output elements: batch (shape_[0]) * output channels * output spatial volume.
symbolic::Expression ConvNode::num_output_elements() const {
    auto batch_times_channels = symbolic::mul(this->shape_[0], this->output_channels_);
    return symbolic::mul(batch_times_channels, output_spatial_volume());
}
667

668
/// Multiply-accumulate taps per output element: (C_in / group) * kernel volume,
/// where C_in is shape_[1].
symbolic::Expression ConvNode::kernel_iteration_count() const {
    auto channels_per_group = symbolic::div(this->shape_[1], this->group_);
    return symbolic::mul(channels_per_group, kernel_volume());
}
672

673
/// Serializes a ConvNode to JSON. It writes "output_channels", "group" and "with_bias", then
/// lets fill_base_values() add the fields shared by the tensor library nodes.
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& conv = static_cast<const ConvNode&>(library_node);
    nlohmann::json out;

    serializer::JSONSerializer expr_writer;
    out["output_channels"] = expr_writer.expression(conv.output_channels());
    out["group"] = expr_writer.expression(conv.group());
    out["with_bias"] = conv.has_bias();

    fill_base_values(conv, out);
    return out;
}
686

687
/// Recreates a ConvNode in `parent` from its JSON form.
/// Optional fields and defaults: "with_bias" (false), "output_channels" (1), "group" (1).
/// The remaining parameters come from deserialize_base_values().
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("kernel_shape"));

    auto base = deserialize_base_values(j);

    const bool with_bias = j.contains("with_bias") && j.at("with_bias").get<bool>();

    // Parses an optional symbolic field, defaulting to one when absent.
    auto parse_or_one = [&j](const char* key) -> symbolic::Expression {
        if (!j.contains(key)) {
            return symbolic::one();
        }
        return symbolic::parse(j.at(key).get<std::string>());
    };
    symbolic::Expression output_channels = parse_or_one("output_channels");
    symbolic::Expression group = parse_or_one("group");

    return builder.add_library_node<ConvNode>(
        parent,
        base.debug_info,
        base.shape,
        base.kernel_shape,
        base.strides,
        base.pads,
        base.dilations,
        output_channels,
        group,
        with_bias,
        base.quantization
    );
}
724

725
} // namespace tensor
726
} // namespace math
727
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc