• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27007027060

05 Jun 2026 09:28AM UTC coverage: 61.275% (-0.02%) from 61.292%
27007027060

push

github

web-flow
Improve Quantization support on TensorNodes (#736)

* Added DataFlowGraph.find_standalone_exit() following the pattern of find_standalone_entry() to abstract away edge types.
* LibNodeDispatcher does not allow missing inputs.
  ConvNode is now explicitly configured with whether it has a bias, to accommodate this.
* Fixed elementwise CMath node toStr()

---------

Co-authored-by: Moritz Timmer <25349452+Moehre2@users.noreply.github.com>

10 of 43 new or added lines in 8 files covered. (23.26%)

1 existing line in 1 file now uncovered.

35592 of 58086 relevant lines covered (61.27%)

11015.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

68.11
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4
#include <sstream>
5
#include <utility>
6

7
#include "sdfg/analysis/analysis.h"
8
#include "sdfg/builder/structured_sdfg_builder.h"
9
#include "sdfg/data_flow/access_node.h"
10
#include "sdfg/data_flow/library_nodes/math/blas/blas_node.h"
11
#include "sdfg/data_flow/library_nodes/stdlib/free.h"
12
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
13
#include "sdfg/data_flow/library_nodes/stdlib/memset.h"
14
#include "sdfg/data_flow/memlet.h"
15
#include "sdfg/data_flow/tasklet.h"
16
#include "sdfg/exceptions.h"
17
#include "sdfg/structured_control_flow/block.h"
18
#include "sdfg/structured_control_flow/map.h"
19
#include "sdfg/structured_control_flow/sequence.h"
20
#include "sdfg/symbolic/symbolic.h"
21
#include "sdfg/types/pointer.h"
22
#include "sdfg/types/scalar.h"
23
#include "sdfg/types/tensor.h"
24
#include "sdfg/types/type.h"
25

26
#include "sdfg/analysis/scope_analysis.h"
27
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
28
#include "symengine/integer.h"
29
#include "symengine/symengine_rcp.h"
30

31
namespace sdfg {
32
namespace math {
33
namespace tensor {
34

35
ConvNode::ConvNode(
36
    size_t element_id,
37
    const DebugInfo& debug_info,
38
    const graph::Vertex vertex,
39
    data_flow::DataFlowGraph& parent,
40
    const std::vector<symbolic::Expression>& shape,
41
    const std::vector<symbolic::Expression>& kernel_shape,
42
    const std::vector<symbolic::Expression>& strides,
43
    const std::vector<symbolic::Expression>& pads,
44
    const std::vector<symbolic::Expression>& dilations,
45
    symbolic::Expression output_channels,
46
    symbolic::Expression group,
47
    bool with_bias,
48
    QuantizationType quantization,
49
    const data_flow::ImplementationType& impl_type
50
)
51
    : SpatialTensorNode(
37✔
52
          element_id,
37✔
53
          debug_info,
37✔
54
          vertex,
37✔
55
          parent,
37✔
56
          LibraryNodeType_Conv,
37✔
57
          {},
37✔
58
          {"Y", "X", "W"}, // X and W are required, B (bias) is optional
37✔
59
          impl_type,
37✔
60
          quantization,
37✔
61
          shape,
37✔
62
          kernel_shape,
37✔
63
          strides,
37✔
64
          pads,
37✔
65
          dilations
37✔
66
      ),
37✔
67
      output_channels_(std::move(output_channels)), group_(std::move(group)), with_bias_(with_bias) {
37✔
68
    if (with_bias) {
37✔
69
        inputs_.push_back("B");
1✔
70
    }
1✔
71
}
37✔
72

73
void ConvNode::validate(const Function& function) const {
73✔
74
    TensorNode::validate(function);
73✔
75

76
    auto& graph = this->get_parent();
73✔
77

78
    // Custom validation for ConvNode that handles optional bias input
79
    // We expect X, W as required inputs and optionally B (bias)
80

81
    // Collect all input edges by connector name
82
    std::map<std::string, const data_flow::Memlet*> input_edges;
73✔
83
    for (auto& iedge : graph.in_edges(*this)) {
219✔
84
        input_edges[iedge.dst_conn()] = &iedge;
219✔
85
    }
219✔
86

87
    // Check that required inputs X and W are present
88
    if (input_edges.find("X") == input_edges.end()) {
73✔
89
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
90
    }
×
91
    if (input_edges.find("W") == input_edges.end()) {
73✔
92
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
93
    }
×
94

95
    // Validate that parameters are not empty
96
    if (shape_.empty()) {
73✔
97
        throw InvalidSDFGException("ConvNode shape cannot be empty");
×
98
    }
×
99
    if (kernel_shape_.empty()) {
73✔
100
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
101
    }
×
102
    if (strides_.empty()) {
73✔
103
        throw InvalidSDFGException("ConvNode strides cannot be empty");
×
104
    }
×
105
    if (pads_.empty()) {
73✔
106
        throw InvalidSDFGException("ConvNode pads cannot be empty");
×
107
    }
×
108
    if (dilations_.empty()) {
73✔
109
        throw InvalidSDFGException("ConvNode dilations cannot be empty");
×
110
    }
×
111

112
    // Validate consistent dimensions
113
    size_t spatial_dims = kernel_shape_.size();
73✔
114

115
    if (shape_.size() != spatial_dims + 2) {
73✔
116
        throw InvalidSDFGException("ConvNode shape must match kernel spatial dimensions + 2");
×
117
    }
×
118

119
    if (strides_.size() != spatial_dims) {
73✔
120
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
121
    }
1✔
122

123
    if (pads_.size() != 2 * spatial_dims) {
72✔
124
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
125
    }
1✔
126

127
    if (dilations_.size() != spatial_dims) {
71✔
128
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
129
    }
×
130

131
    // Validate groups
132
    if (SymEngine::is_a<SymEngine::Integer>(*this->group_)) {
71✔
133
        auto group_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->group_)->as_int();
71✔
134
        if (SymEngine::is_a<SymEngine::Integer>(*this->shape_[1])) {
71✔
135
            auto input_channels_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(this->shape_[1])->as_int();
71✔
136
            if (input_channels_int % group_int != 0) {
71✔
137
                throw InvalidSDFGException("ConvNode input channels must be divisible by groups");
×
138
            }
×
139
        }
71✔
140
        if (SymEngine::is_a<SymEngine::Integer>(*this->output_channels_)) {
71✔
141
            auto output_channels_int =
71✔
142
                SymEngine::rcp_static_cast<const SymEngine::Integer>(this->output_channels_)->as_int();
71✔
143
            if (output_channels_int % group_int != 0) {
71✔
144
                throw InvalidSDFGException("ConvNode output channels must be divisible by groups");
×
145
            }
×
146
        }
71✔
147
    }
71✔
148
}
71✔
149

150
// Maps an element type onto the BLAS precision used by the GEMM in expand().
// Only Half, Float and Double are supported; any other primitive type yields
// BLAS_Precision::invalid, which makes expand() bail out.
blas::BLAS_Precision ConvNode::get_blas_precision(types::Scalar base_type) {
    const auto primitive = base_type.primitive_type();
    if (primitive == types::PrimitiveType::Float) {
        return blas::BLAS_Precision::s;
    }
    if (primitive == types::PrimitiveType::Double) {
        return blas::BLAS_Precision::d;
    }
    if (primitive == types::PrimitiveType::Half) {
        return blas::BLAS_Precision::h;
    }
    return blas::BLAS_Precision::invalid;
}
7✔
162

163
// Returns the spatial output extent for every kernel axis i:
//   out[i] = (shape[i + 2] + pads[i] + pads[dims + i]
//             - dilations[i] * (kernel_shape[i] - 1) - 1) / strides[i] + 1
// where dims = kernel_shape_.size(); pads holds all start pads followed by all end pads.
symbolic::MultiExpression ConvNode::get_out_shape() {
    const size_t spatial_rank = kernel_shape_.size();

    symbolic::MultiExpression result;
    result.reserve(spatial_rank);

    for (size_t axis = 0; axis < spatial_rank; ++axis) {
        // Input extent along this axis including start and end padding.
        auto padded = symbolic::add(
            this->shape_[axis + 2], symbolic::add(this->pads_[axis], this->pads_[spatial_rank + axis])
        );
        // Extra reach of the dilated kernel beyond its first tap.
        auto dilated_span =
            symbolic::mul(this->dilations_[axis], symbolic::sub(this->kernel_shape_[axis], symbolic::one()));
        auto numerator = symbolic::sub(symbolic::sub(padded, dilated_span), symbolic::one());

        result.push_back(symbolic::add(symbolic::div(numerator, this->strides_[axis]), symbolic::one()));
    }

    return result;
}
7✔
185

NEW
186
bool ConvNode::has_bias() const { return with_bias_; }
×
187

188
bool ConvNode::check_expandable(
189
    data_flow::DataFlowGraph& dfg, analysis::AnalysisManager& analysis_manager, ConvExpandPrerequisits& boundary
190
) const {
9✔
191
    if ((dfg.nodes().size() != 4 || dfg.edges().size() != 3) && (dfg.nodes().size() != 5 || dfg.edges().size() != 4)) {
9✔
192
        return false;
1✔
193
    }
1✔
194

195
    // Get edges
196
    boundary.iedge_X = dfg.in_edge_for_connector(*this, "X");
8✔
197
    boundary.iedge_W = dfg.in_edge_for_connector(*this, "W");
8✔
198
    boundary.iedge_B = with_bias_ ? dfg.in_edge_for_connector(*this, "B") : nullptr;
8✔
199
    boundary.iedge_Y = dfg.in_edge_for_connector(*this, "Y");
8✔
200
    if (!boundary.iedge_X || !boundary.iedge_W || !boundary.iedge_Y) {
8✔
201
        return false;
×
202
    }
×
203
    boundary.has_bias = boundary.iedge_B != nullptr;
8✔
204

205
    // Get access nodes
206
    boundary.access_X = dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_X->src());
8✔
207
    boundary.access_W = dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_W->src());
8✔
208
    boundary.access_B =
8✔
209
        (boundary.has_bias ? dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_B->src()) : nullptr);
8✔
210
    boundary.access_Y = dynamic_cast<const data_flow::AccessNode*>(&boundary.iedge_Y->src());
8✔
211
    if (!boundary.access_X || !boundary.access_W || (boundary.has_bias && !boundary.access_B) || !boundary.access_Y) {
8✔
212
        return false;
×
213
    }
×
214

215
    // Get block & its parent
216
    boundary.block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
8✔
217
    if (!boundary.block) {
8✔
218
        return false;
×
219
    }
×
220

221
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
8✔
222
    boundary.block_parent = dynamic_cast<structured_control_flow::Sequence*>(scope_analysis.parent_scope(boundary.block)
8✔
223
    );
8✔
224
    if (!boundary.block_parent) {
8✔
225
        return false;
×
226
    }
×
227

228
    boundary.block_index = boundary.block_parent->index(*boundary.block);
8✔
229
    if (boundary.block_index >= boundary.block_parent->size()) {
8✔
230
        return false;
×
231
    }
×
232

233
    return true;
8✔
234
}
8✔
235

236
// Lowers the convolution into explicit loops plus a BLAS GEMM.
// The new code goes into a sequence inserted before the node's block; the original
// block is then removed. Nested inside loops over the batch index n and the group g:
//   1. malloc a patches buffer of shape [C_in/group, kernel_shape..., out_shape...]
//      and fill it from X. Entries whose strided/padded/dilated input index falls
//      outside X are set to 0.0.
//   2. Run a row-major GEMM  Y_g[m x n] = W_g[m x k] * patches[k x n], where
//        m = C_out/group
//        n = prod(out_shape)
//        k = (C_in/group) * prod(kernel_shape)
//      with alpha = 1 and beta = 0, so the Y slice is overwritten.
//   3. If a bias is connected, add B[out_channels * g + l] to every element of
//      output channel l in that Y slice.
//   4. Free the patches buffer.
// Returns false without modifying anything if check_expandable() fails or the element
// type has no BLAS precision. Returns true after the rewrite.
// NOTE: the order of builder calls below determines generated container names and
// element ids; keep it stable.
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // Validate nodes are standalone in the data flow graph
    auto& dfg = this->get_parent();
    ConvExpandPrerequisits b;
    if (!check_expandable(dfg, analysis_manager, b)) {
        return false;
    }

    // Determine BLAS precision

    types::Scalar base_type(this->primitive_type(dfg));
    blas::BLAS_Precision precision = get_blas_precision(base_type);
    if (precision == blas::BLAS_Precision::invalid) {
        return false;
    }

    // Create new sequence for expansion
    // (inserted before the original block and given that block's transition assignments)
    auto& new_sequence = builder.add_sequence_before(
        *b.block_parent, *b.block, b.block_parent->at(b.block_index).second.assignments(), b.block->debug_info()
    );

    // Dimensions, i.e., 1D, 2D, 3D, ...
    size_t dims = this->kernel_shape_.size();
    symbolic::MultiExpression out_shape = get_out_shape();
    types::Scalar indvar_type(types::PrimitiveType::Int64);

    // Per-group channel counts: C_in / group and C_out / group.
    auto in_channels = symbolic::div(this->shape_[1], this->group_);
    auto out_channels = symbolic::div(this->output_channels_, this->group_);

    // Add loop over batch size
    auto n_container = builder.find_new_name("_n");
    builder.add_container(n_container, indvar_type);
    auto n = symbolic::symbol(n_container);
    auto& loop_n = builder.add_map(
        new_sequence,
        n,
        symbolic::Lt(n, this->shape_[0]),
        symbolic::zero(),
        symbolic::add(n, symbolic::one()),
        ScheduleType_Sequential::create(),
        {},
        b.block->debug_info()
    );

    // Add loop over groups
    auto g_container = builder.find_new_name("_g");
    builder.add_container(g_container, indvar_type);
    auto g = symbolic::symbol(g_container);
    auto& loop_g = builder.add_map(
        loop_n.root(),
        g,
        symbolic::Lt(g, this->group_),
        symbolic::zero(),
        symbolic::add(g, symbolic::one()),
        ScheduleType_Sequential::create(),
        {},
        b.block->debug_info()
    );

    // Add patches container with malloc
    // Element count: (C_in / group) * prod(kernel_shape[i] * out_shape[i]).
    symbolic::Expression patches_size = in_channels;
    for (size_t i = 0; i < dims; i++) {
        patches_size = symbolic::mul(patches_size, symbolic::mul(this->kernel_shape_[i], out_shape[i]));
    }
    types::Pointer patches_type(base_type);
    auto patches_container = builder.find_new_name("_patches");
    builder.add_container(patches_container, patches_type);
    auto [patches_malloc_block, patches_malloc_node] = stdlib::add_malloc_block(
        builder,
        loop_g.root(),
        patches_container,
        symbolic::mul(patches_size, symbolic::size_of_type(base_type)),
        patches_type,
        this->debug_info()
    );

    // Add loop over channels
    structured_control_flow::Sequence* current_seq = &loop_g.root();
    auto c_container = builder.find_new_name("_c");
    builder.add_container(c_container, indvar_type);
    auto c = symbolic::symbol(c_container);
    auto& loop_c = builder.add_map(
        *current_seq,
        c,
        symbolic::Lt(c, in_channels),
        symbolic::zero(),
        symbolic::add(c, symbolic::one()),
        ScheduleType_Sequential::create(),
        {},
        b.block->debug_info()
    );
    current_seq = &loop_c.root();

    // Add loops over kernel shape
    symbolic::SymbolVec ks;
    ks.reserve(dims);
    for (size_t i = 0; i < dims; i++) {
        auto k_container = builder.find_new_name("_k");
        builder.add_container(k_container, indvar_type);
        auto k = symbolic::symbol(k_container);
        ks.push_back(k);
        auto& loop_k = builder.add_map(
            *current_seq,
            k,
            symbolic::Lt(k, this->kernel_shape_[i]),
            symbolic::zero(),
            symbolic::add(k, symbolic::one()),
            ScheduleType_Sequential::create(),
            {},
            b.block->debug_info()
        );
        current_seq = &loop_k.root();
    }

    // Add loops over output dimensions
    symbolic::SymbolVec os;
    os.reserve(dims);
    for (size_t i = 0; i < dims; i++) {
        auto o_container = builder.find_new_name("_o");
        builder.add_container(o_container, indvar_type);
        auto o = symbolic::symbol(o_container);
        os.push_back(o);
        auto& loop_o = builder.add_map(
            *current_seq,
            o,
            symbolic::Lt(o, out_shape[i]),
            symbolic::zero(),
            symbolic::add(o, symbolic::one()),
            ScheduleType_Sequential::create(),
            {},
            b.block->debug_info()
        );
        current_seq = &loop_o.root();
    }

    // Add if/else to stay in bounds for copying
    // Input index along axis i: o[i] * stride[i] - pad_start[i] + k[i] * dilation[i].
    // copy_condition: every index is within [0, shape[i + 2]).
    // zero_condition: some index is outside that range (the complement).
    symbolic::MultiExpression is;
    is.reserve(dims);
    symbolic::Condition copy_condition = symbolic::__true__();
    symbolic::Condition zero_condition = symbolic::__false__();
    for (size_t i = 0; i < dims; i++) {
        auto i_expr = symbolic::
            add(symbolic::sub(symbolic::mul(os[i], this->strides_[i]), this->pads_[i]),
                symbolic::mul(ks[i], this->dilations_[i]));
        is.push_back(i_expr);
        copy_condition = symbolic::
            And(copy_condition,
                symbolic::And(symbolic::Lt(i_expr, this->shape_[i + 2]), symbolic::Ge(i_expr, symbolic::zero())));
        zero_condition = symbolic::
            Or(zero_condition,
               symbolic::Or(symbolic::Ge(i_expr, this->shape_[i + 2]), symbolic::Lt(i_expr, symbolic::zero())));
    }
    auto& branch = builder.add_if_else(*current_seq, {}, b.block->debug_info());
    auto& copy_case = builder.add_case(branch, copy_condition, b.block->debug_info());
    auto& zero_case = builder.add_case(branch, zero_condition, b.block->debug_info());

    // Determine patches subset & tensor type
    // patches is viewed as a tensor [C_in/group, kernel_shape..., out_shape...] indexed by (c, k..., o...).
    data_flow::Subset patches_subset;
    patches_subset.push_back(c);
    patches_subset.insert(patches_subset.end(), ks.begin(), ks.end());
    patches_subset.insert(patches_subset.end(), os.begin(), os.end());
    symbolic::MultiExpression patches_shape;
    patches_shape.push_back(in_channels);
    patches_shape.insert(patches_shape.end(), this->kernel_shape_.begin(), this->kernel_shape_.end());
    patches_shape.insert(patches_shape.end(), out_shape.begin(), out_shape.end());
    types::Tensor patches_tensor_type(base_type, patches_shape);

    // Determine subset for X
    // X is indexed as (n, in_channels * g + c, input indices...).
    data_flow::Subset subset_X;
    subset_X.push_back(n);
    subset_X.push_back(symbolic::add(symbolic::mul(in_channels, g), c));
    subset_X.insert(subset_X.end(), is.begin(), is.end());

    // Add copy from X to patches
    auto& copy_block = builder.add_block(copy_case, {}, b.block->debug_info());
    {
        auto& X_access = builder.add_access(copy_block, b.access_X->data(), b.access_X->debug_info());
        auto& patches_access = builder.add_access(copy_block, patches_container, this->debug_info());
        auto& tasklet =
            builder.add_tasklet(copy_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
        builder.add_computational_memlet(
            copy_block, X_access, tasklet, "_in", subset_X, b.iedge_X->base_type(), b.iedge_X->debug_info()
        );
        builder.add_computational_memlet(
            copy_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
        );
    }

    // Add zero assignment to patches (implements zero padding)
    auto& zero_block = builder.add_block(zero_case, {}, b.block->debug_info());
    {
        auto& constant_zero = builder.add_constant(zero_block, "0.0", base_type, this->debug_info());
        auto& patches_access = builder.add_access(zero_block, patches_container, this->debug_info());
        auto& tasklet =
            builder.add_tasklet(zero_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
        builder.add_computational_memlet(zero_block, constant_zero, tasklet, "_in", {}, base_type, this->debug_info());
        builder.add_computational_memlet(
            zero_block, tasklet, "_out", patches_access, patches_subset, patches_tensor_type, this->debug_info()
        );
    }

    // Add reference to W
    // Flat offset of group g's filters: (C_out/group) * g * (C_in/group) * prod(kernel_shape).
    // Assumes W is contiguous as [C_out, C_in/group, kernel...] (TODO confirm layout contract).
    auto ref_W_container = builder.find_new_name("_ref_W");
    types::Scalar ref_W_base_type(builder.subject().type(b.access_W->data()).primitive_type());
    types::Pointer ref_W_type(ref_W_base_type);
    builder.add_container(ref_W_container, ref_W_type);
    auto ref_W_subset = symbolic::mul(symbolic::mul(out_channels, g), in_channels);
    for (size_t i = 0; i < dims; i++) {
        ref_W_subset = symbolic::mul(ref_W_subset, this->kernel_shape_[i]);
    }
    auto& ref_W_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
    {
        auto& W_access = builder.add_access(ref_W_block, b.access_W->data(), b.access_W->debug_info());
        auto& ref_W_access = builder.add_access(ref_W_block, ref_W_container, b.access_W->debug_info());
        builder.add_reference_memlet(ref_W_block, W_access, ref_W_access, {ref_W_subset}, ref_W_type);
    }

    // Add reference to Y
    // Flat offset of the (n, g) output slice: (C_out * n + (C_out/group) * g) * prod(out_shape).
    // Assumes Y is contiguous as [N, C_out, out_shape...] (TODO confirm layout contract).
    auto ref_Y_container = builder.find_new_name("_ref_Y");
    types::Scalar ref_Y_base_type(builder.subject().type(b.access_Y->data()).primitive_type());
    types::Pointer ref_Y_type(ref_Y_base_type);
    builder.add_container(ref_Y_container, ref_Y_type);
    auto ref_Y_subset = symbolic::add(symbolic::mul(this->output_channels_, n), symbolic::mul(out_channels, g));
    for (size_t i = 0; i < dims; i++) {
        ref_Y_subset = symbolic::mul(ref_Y_subset, out_shape[i]);
    }
    auto& ref_Y_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
    {
        auto& Y_access = builder.add_access(ref_Y_block, b.access_Y->data(), b.access_Y->debug_info());
        auto& ref_Y_access = builder.add_access(ref_Y_block, ref_Y_container, b.access_Y->debug_info());
        builder.add_reference_memlet(ref_Y_block, Y_access, ref_Y_access, {ref_Y_subset}, ref_Y_type);
    }

    // Add GEMM node
    // Row-major C[m x n] = alpha * A[m x k] * B[k x n] + beta * C, with A = W slice,
    // B = patches, C = Y slice; alpha = 1, beta = 0.
    auto& gemm_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
    {
        auto& alpha = builder.add_constant(gemm_block, "1.0", base_type, this->debug_info());
        auto& beta = builder.add_constant(gemm_block, "0.0", base_type, this->debug_info());
        auto& ref_W_access = builder.add_access(gemm_block, ref_W_container, b.access_W->debug_info());
        auto& patches_access = builder.add_access(gemm_block, patches_container, this->debug_info());
        auto& ref_Y_access_in = builder.add_access(gemm_block, ref_Y_container, b.access_Y->debug_info());
        symbolic::Expression gemm_m = out_channels;
        symbolic::Expression gemm_n = symbolic::one();
        symbolic::Expression gemm_k = in_channels;
        for (size_t i = 0; i < dims; i++) {
            gemm_n = symbolic::mul(gemm_n, out_shape[i]);
            gemm_k = symbolic::mul(gemm_k, this->kernel_shape_[i]);
        }
        auto& libnode = builder.add_library_node<blas::GEMMNode>(
            gemm_block,
            this->debug_info(),
            blas::ImplementationType_BLAS,
            precision, // precision
            blas::BLAS_Layout::RowMajor, // layout
            blas::BLAS_Transpose::No, // transA
            blas::BLAS_Transpose::No, // transB
            gemm_m, // m
            gemm_n, // n
            gemm_k, // k
            gemm_k, // lda
            gemm_n, // ldb
            gemm_n // ldc
        );
        builder.add_computational_memlet(gemm_block, alpha, libnode, "__alpha", {}, base_type, this->debug_info());
        builder.add_computational_memlet(gemm_block, beta, libnode, "__beta", {}, base_type, this->debug_info());
        builder
            .add_computational_memlet(gemm_block, ref_W_access, libnode, "__A", {}, ref_W_type, b.iedge_W->debug_info());
        builder
            .add_computational_memlet(gemm_block, patches_access, libnode, "__B", {}, patches_type, this->debug_info());
        builder.add_computational_memlet(
            gemm_block, ref_Y_access_in, libnode, "__C", {}, ref_Y_type, b.iedge_Y->debug_info()
        );
    }

    // Add bias if available
    // Runs after the GEMM inside the same group iteration, adding B to each output element.
    if (b.has_bias) {
        // Add loop over output channels
        auto l_container = builder.find_new_name("_l");
        builder.add_container(l_container, indvar_type);
        auto l = symbolic::symbol(l_container);
        auto& loop_l = builder.add_map(
            loop_g.root(),
            l,
            symbolic::Lt(l, out_channels),
            symbolic::zero(),
            symbolic::add(l, symbolic::one()),
            ScheduleType_Sequential::create(),
            {},
            b.block->debug_info()
        );
        current_seq = &loop_l.root();

        // Add loops over output dimensions (again)
        // Fresh induction variables replace the entries of `os` used for the patch loops.
        for (size_t i = 0; i < dims; i++) {
            auto o_container = builder.find_new_name("_o");
            builder.add_container(o_container, indvar_type);
            auto o = symbolic::symbol(o_container);
            auto& loop_o = builder.add_map(
                *current_seq,
                o,
                symbolic::Lt(o, out_shape[i]),
                symbolic::zero(),
                symbolic::add(o, symbolic::one()),
                ScheduleType_Sequential::create(),
                {},
                b.block->debug_info()
            );
            current_seq = &loop_o.root();
            os[i] = o;
        }

        // Add bias to Y
        // Y[n, (C_out/group) * g + l, o...] += B[(C_out/group) * g + l]
        data_flow::Subset Y_subset;
        Y_subset.push_back(n);
        Y_subset.push_back(symbolic::add(symbolic::mul(out_channels, g), l));
        Y_subset.insert(Y_subset.end(), os.begin(), os.end());
        auto B_subset = symbolic::add(symbolic::mul(out_channels, g), l);
        auto& bias_block = builder.add_block(*current_seq, {}, b.block->debug_info());
        {
            auto& B_access = builder.add_access(bias_block, b.access_B->data(), b.access_B->debug_info());
            auto& Y_access_in = builder.add_access(bias_block, b.access_Y->data(), b.access_Y->debug_info());
            auto& Y_access_out = builder.add_access(bias_block, b.access_Y->data(), b.access_Y->debug_info());
            auto& tasklet =
                builder
                    .add_tasklet(bias_block, data_flow::TaskletCode::fp_add, "_out", {"_in1", "_in2"}, this->debug_info());
            builder.add_computational_memlet(
                bias_block, Y_access_in, tasklet, "_in1", Y_subset, b.iedge_Y->base_type(), this->debug_info()
            );
            builder.add_computational_memlet(
                bias_block, B_access, tasklet, "_in2", {B_subset}, b.iedge_B->base_type(), b.iedge_B->debug_info()
            );
            builder.add_computational_memlet(
                bias_block, tasklet, "_out", Y_access_out, Y_subset, b.iedge_Y->base_type(), b.iedge_Y->debug_info()
            );
        }
    }

    // Add free for patches container
    auto& patches_free_block = builder.add_block(loop_g.root(), {}, b.block->debug_info());
    {
        auto& patches_access_in = builder.add_access(patches_free_block, patches_container, this->debug_info());
        auto& libnode = builder.add_library_node<stdlib::FreeNode>(patches_free_block, this->debug_info());
        builder.add_computational_memlet(
            patches_free_block, patches_access_in, libnode, "_ptr", {}, patches_type, this->debug_info()
        );
    }

    // Clean up the original block
    builder.clear_code_node_legacy(*b.block, *this);
    // WARNING: this has been deallocated at this point!!
    // Only the local `b` may be used from here on. The original block now sits at
    // block_index + 1 because new_sequence was inserted before it.
    builder.remove_child(*b.block_parent, b.block_index + 1);

    return true;
}
5✔
590

591
// Free symbols of the node: everything reported by SpatialTensorNode::symbols()
// plus the atoms of output_channels_ and group_.
symbolic::SymbolSet ConvNode::symbols() const {
    auto result = SpatialTensorNode::symbols();
    for (const auto& param : {output_channels_, group_}) {
        for (const auto& atom : symbolic::atoms(param)) {
            result.insert(atom);
        }
    }
    return result;
}
13✔
602

603
// Substitutes old_expression with new_expression in every symbolic parameter:
// first the spatial ones handled by the base class, then output_channels_ and group_.
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    SpatialTensorNode::replace(old_expression, new_expression);

    auto substitute = [&](const symbolic::Expression& expr) {
        return symbolic::subs(expr, old_expression, new_expression);
    };
    output_channels_ = substitute(output_channels_);
    group_ = substitute(group_);
}
×
608

609
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
610
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
1✔
611
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
1✔
612
        element_id,
1✔
613
        this->debug_info(),
1✔
614
        vertex,
1✔
615
        parent,
1✔
616
        shape_,
1✔
617
        kernel_shape_,
1✔
618
        strides_,
1✔
619
        pads_,
1✔
620
        dilations_,
1✔
621
        output_channels_,
1✔
622
        group_,
1✔
623
        with_bias_,
1✔
624
        fixed_quantization_,
1✔
625
        implementation_type_
1✔
626
    ));
1✔
627
}
1✔
628

629
std::string ConvNode::toStr() const {
×
630
    std::stringstream result;
×
631
    result << "Conv(";
×
632
    SpatialTensorNode::operator<<(result);
×
633

634
    result << ", output_channels=" + output_channels_->__str__();
×
635
    result << ", group=" + group_->__str__() + ")";
×
636
    return result.str();
×
637
}
×
638

639
// Floating-point operation count. Each output element is a dot product of length
// K = kernel_iteration_count(), i.e. K multiplications and K - 1 additions:
//   total = outputs * K + outputs * (K - 1)
symbolic::Expression ConvNode::flop() const {
    const auto outputs = num_output_elements();
    const auto reduction_len = kernel_iteration_count();

    auto multiplications = symbolic::mul(outputs, reduction_len);
    auto additions = symbolic::mul(outputs, symbolic::sub(reduction_len, symbolic::one()));
    return symbolic::add(multiplications, additions);
}
×
649

650
// Pointer access pattern per input connector.
//  - Index 0 ("Y") is fully written.
//  - Every other declared input (X, W and, if present, B) is read-only.
//  - Indices outside the connector list are delegated to TensorNode.
data_flow::PointerAccessType ConvNode::pointer_access_type(int input_idx) const {
    if (input_idx == 0) {
        return data_flow::PointerAccessMeta::create_full_write_only(symbolic::__nullptr__(), true);
    }

    // input_idx > 0 is checked first, so the cast to size_t cannot wrap.
    const bool is_read_input = input_idx > 0 && static_cast<size_t>(input_idx) < inputs_.size();
    if (is_read_input) {
        return data_flow::PointerAccessMeta::create_read_only(symbolic::__nullptr__(), true);
    }

    return TensorNode::pointer_access_type(input_idx);
}
×
659

660
// Number of output elements: N * C_out * prod(output spatial extents).
symbolic::Expression ConvNode::num_output_elements() const {
    auto batch_times_channels = symbolic::mul(shape_[0], output_channels_);
    return symbolic::mul(batch_times_channels, output_spatial_volume());
}
×
664

665
// Reduction length per output element: (C_in / group) * prod(kernel_shape_[i]).
symbolic::Expression ConvNode::kernel_iteration_count() const {
    auto channels_per_group = symbolic::div(shape_[1], group_);
    return symbolic::mul(channels_per_group, kernel_volume());
}
×
669

670
// Serializes a ConvNode to JSON.
// Conv-specific keys are "output_channels" and "group" (as expression strings) and
// "with_bias" (bool). The shared spatial fields are added by fill_base_values().
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& conv = static_cast<const ConvNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    nlohmann::json j;
    j["output_channels"] = expr_serializer.expression(conv.output_channels());
    j["group"] = expr_serializer.expression(conv.group());
    j["with_bias"] = conv.has_bias();

    fill_base_values(conv, j);
    return j;
}
×
683

684
// Rebuilds a ConvNode from JSON and adds it to `parent`.
// Optional keys and their defaults:
//  - "output_channels", "group": expression strings, default to 1 when absent;
//  - "with_bias": default false when absent.
// The shared spatial fields are read by deserialize_base_values().
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("kernel_shape"));

    auto base = deserialize_base_values(j);

    bool with_bias = false;
    if (const auto bias_it = j.find("with_bias"); bias_it != j.end()) {
        with_bias = bias_it->get<bool>();
    }

    // Parses an optional expression-valued key, falling back to the constant 1.
    auto expression_or_one = [&j](const char* key) -> symbolic::Expression {
        if (!j.contains(key)) {
            return symbolic::one();
        }
        return symbolic::parse(j.at(key).get<std::string>());
    };
    symbolic::Expression output_channels = expression_or_one("output_channels");
    symbolic::Expression group = expression_or_one("group");

    return builder.add_library_node<ConvNode>(
        parent,
        base.debug_info,
        base.shape,
        base.kernel_shape,
        base.strides,
        base.pads,
        base.dilations,
        output_channels,
        group,
        with_bias,
        base.quantization
    );
}
×
721

722
} // namespace tensor
723
} // namespace math
724
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc