• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4

5
#include "sdfg/analysis/analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7
#include "sdfg/types/type.h"
8

9
#include "sdfg/analysis/scope_analysis.h"
10
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
11

12
namespace sdfg {
13
namespace math {
14
namespace tensor {
15

16
/// Constructs a Conv library node.
///
/// Registers the node with one output connector ("Y") and three input
/// connectors ("X", "W", "B"); only X (input tensor) and W (weights) are
/// required — B (bias) is optional and may stay unconnected.
///
/// @param element_id       unique element id within the SDFG
/// @param debug_info       source-location info carried through expansion
/// @param vertex           graph vertex this node occupies
/// @param parent           owning dataflow graph
/// @param shape            input shape, expected layout [N, C_in, D0, D1, ...] — see expand()
/// @param kernel_shape     spatial kernel extents (one entry per spatial dim; must be non-empty, see validate())
/// @param strides          per-spatial-dim strides (may be empty → defaults to 1)
/// @param pads             [begin_0..begin_n, end_0..end_n] (may be empty → defaults to 0)
/// @param dilations        per-spatial-dim dilations (may be empty → defaults to 1)
/// @param output_channels  C_out expression
/// @param group            convolution group count expression
ConvNode::ConvNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<symbolic::Expression>& kernel_shape,
    const std::vector<symbolic::Expression>& strides,
    const std::vector<symbolic::Expression>& pads,
    const std::vector<symbolic::Expression>& dilations,
    symbolic::Expression output_channels,
    symbolic::Expression group
)
    : TensorNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_Conv,
          {"Y"},
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
          data_flow::ImplementationType_NONE
      ),
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
      output_channels_(output_channels), group_(group) {}
41

UNCOV
42
void ConvNode::validate(const Function& function) const {
×
UNCOV
43
    TensorNode::validate(function);
×
44

UNCOV
45
    auto& graph = this->get_parent();
×
46

47
    // Custom validation for ConvNode that handles optional bias input
48
    // We expect X, W as required inputs and optionally B (bias)
49

50
    // Collect all input edges by connector name
UNCOV
51
    std::map<std::string, const data_flow::Memlet*> input_edges;
×
UNCOV
52
    for (auto& iedge : graph.in_edges(*this)) {
×
UNCOV
53
        input_edges[iedge.dst_conn()] = &iedge;
×
UNCOV
54
    }
×
55

56
    // Check that required inputs X and W are present
UNCOV
57
    if (input_edges.find("X") == input_edges.end()) {
×
58
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
59
    }
×
UNCOV
60
    if (input_edges.find("W") == input_edges.end()) {
×
61
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
62
    }
×
63

64
    // Validate kernel shape is not empty
UNCOV
65
    if (kernel_shape_.empty()) {
×
66
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
67
    }
×
68

69
    // Validate strides, pads, dilations have consistent dimensions
UNCOV
70
    size_t spatial_dims = kernel_shape_.size();
×
71

UNCOV
72
    if (!strides_.empty() && strides_.size() != spatial_dims) {
×
UNCOV
73
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
×
UNCOV
74
    }
×
75

UNCOV
76
    if (!pads_.empty() && pads_.size() != 2 * spatial_dims) {
×
UNCOV
77
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
×
UNCOV
78
    }
×
79

UNCOV
80
    if (!dilations_.empty() && dilations_.size() != spatial_dims) {
×
81
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
82
    }
×
UNCOV
83
}
×
84

UNCOV
85
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
UNCOV
86
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
87

UNCOV
88
    auto& dataflow = this->get_parent();
×
UNCOV
89
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
UNCOV
90
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
UNCOV
91
    int index = parent.index(block);
×
UNCOV
92
    auto& transition = parent.at(index).second;
×
93

94
    // Get primitive type
UNCOV
95
    auto primitive_type = this->primitive_type(dataflow);
×
UNCOV
96
    types::Scalar scalar_type(primitive_type);
×
97

98
    // Get input edges
UNCOV
99
    auto in_edges = dataflow.in_edges(*this);
×
UNCOV
100
    auto in_edges_it = in_edges.begin();
×
101

UNCOV
102
    data_flow::Memlet* x_edge = nullptr;
×
UNCOV
103
    data_flow::Memlet* w_edge = nullptr;
×
UNCOV
104
    data_flow::Memlet* b_edge = nullptr;
×
105

UNCOV
106
    while (in_edges_it != in_edges.end()) {
×
UNCOV
107
        auto& edge = *in_edges_it;
×
UNCOV
108
        auto dst_conn = edge.dst_conn();
×
UNCOV
109
        if (dst_conn == "X") {
×
UNCOV
110
            x_edge = &edge;
×
UNCOV
111
        } else if (dst_conn == "W") {
×
UNCOV
112
            w_edge = &edge;
×
UNCOV
113
        } else if (dst_conn == "B") {
×
114
            b_edge = &edge;
×
115
        } else {
×
116
            throw InvalidSDFGException("ConvNode has unexpected input: " + dst_conn);
×
117
        }
×
UNCOV
118
        ++in_edges_it;
×
UNCOV
119
    }
×
120

UNCOV
121
    if (!x_edge || !w_edge) {
×
122
        throw InvalidSDFGException("ConvNode requires X and W inputs");
×
123
    }
×
124

UNCOV
125
    auto& y_edge = *dataflow.out_edges(*this).begin();
×
126

127
    // Get access nodes
UNCOV
128
    auto* x_node = static_cast<data_flow::AccessNode*>(&x_edge->src());
×
UNCOV
129
    auto* w_node = static_cast<data_flow::AccessNode*>(&w_edge->src());
×
UNCOV
130
    data_flow::AccessNode* b_node = b_edge ? static_cast<data_flow::AccessNode*>(&b_edge->src()) : nullptr;
×
UNCOV
131
    auto* y_node = static_cast<data_flow::AccessNode*>(&y_edge.dst());
×
132

133
    // Validate nodes are standalone in the block
UNCOV
134
    if (!x_node || dataflow.in_degree(*x_node) != 0 || !w_node || dataflow.in_degree(*w_node) != 0 || !y_node ||
×
UNCOV
135
        dataflow.out_degree(*y_node) != 0) {
×
136
        return false;
×
137
    }
×
138

UNCOV
139
    if (b_node && dataflow.in_degree(*b_node) != 0) {
×
140
        return false;
×
141
    }
×
142

143
    // Check that all other nodes in the block are the expected ones
UNCOV
144
    for (auto* nd : dataflow.data_nodes()) {
×
UNCOV
145
        if (nd != x_node && nd != w_node && nd != y_node && (!b_node || nd != b_node)) {
×
146
            return false; // there are other nodes we cannot handle
×
147
        }
×
UNCOV
148
    }
×
149

150
    // Support n-dimensional convolutions
UNCOV
151
    size_t spatial_dims = kernel_shape_.size();
×
152

UNCOV
153
    if (spatial_dims == 0) {
×
154
        return false; // Need at least 1 spatial dimension
×
155
    }
×
156

157
    // Get strides (default to 1 if not provided)
UNCOV
158
    std::vector<symbolic::Expression> strides_vec;
×
UNCOV
159
    for (size_t i = 0; i < spatial_dims; ++i) {
×
UNCOV
160
        if (i < strides_.size()) {
×
UNCOV
161
            strides_vec.push_back(strides_[i]);
×
UNCOV
162
        } else {
×
163
            strides_vec.push_back(symbolic::one());
×
164
        }
×
UNCOV
165
    }
×
166

167
    // Get padding (default to 0 if not provided)
168
    // Pads format: [begin_0, begin_1, ..., begin_n, end_0, end_1, ..., end_n]
UNCOV
169
    std::vector<symbolic::Expression> pads_begin_vec;
×
UNCOV
170
    std::vector<symbolic::Expression> pads_end_vec;
×
UNCOV
171
    for (size_t i = 0; i < spatial_dims; ++i) {
×
UNCOV
172
        if (i < pads_.size()) {
×
UNCOV
173
            pads_begin_vec.push_back(pads_[i]);
×
UNCOV
174
        } else {
×
175
            pads_begin_vec.push_back(symbolic::zero());
×
176
        }
×
177

UNCOV
178
        if (spatial_dims + i < pads_.size()) {
×
UNCOV
179
            pads_end_vec.push_back(pads_[spatial_dims + i]);
×
UNCOV
180
        } else {
×
181
            pads_end_vec.push_back(symbolic::zero());
×
182
        }
×
UNCOV
183
    }
×
184

185
    // Get dilations (default to 1 if not provided)
UNCOV
186
    std::vector<symbolic::Expression> dilations_vec;
×
UNCOV
187
    for (size_t i = 0; i < spatial_dims; ++i) {
×
UNCOV
188
        if (i < dilations_.size()) {
×
UNCOV
189
            dilations_vec.push_back(dilations_[i]);
×
UNCOV
190
        } else {
×
191
            dilations_vec.push_back(symbolic::one());
×
192
        }
×
UNCOV
193
    }
×
194

195
    // Get variable names
UNCOV
196
    auto& X_var = x_node->data();
×
UNCOV
197
    auto& W_var = w_node->data();
×
UNCOV
198
    auto& Y_var = y_node->data();
×
199

200
    // Use shape_ for dimensions if available
201
    // For a generic n-dimensional implementation:
202
    // Input X shape: [N, C_in, D0_in, D1_in, ..., Dn_in]
UNCOV
203
    symbolic::Expression N, C_in;
×
UNCOV
204
    std::vector<symbolic::Expression> input_spatial_dims;
×
205

UNCOV
206
    if (shape_.size() >= 2 + spatial_dims) {
×
UNCOV
207
        N = shape_[0];
×
UNCOV
208
        C_in = shape_[1];
×
UNCOV
209
        for (size_t i = 0; i < spatial_dims; ++i) {
×
UNCOV
210
            input_spatial_dims.push_back(shape_[2 + i]);
×
UNCOV
211
        }
×
UNCOV
212
    } else {
×
213
        N = symbolic::symbol(builder.find_new_name("N"));
×
214
        C_in = symbolic::symbol(builder.find_new_name("C_in"));
×
215
        for (size_t i = 0; i < spatial_dims; ++i) {
×
216
            input_spatial_dims.push_back(symbolic::symbol(builder.find_new_name("D" + std::to_string(i) + "_in")));
×
217
        }
×
218
    }
×
219

220
    // Output Channel (C_out) is passed via constructor
UNCOV
221
    auto C_out = output_channels_;
×
222

223
    // Calculate output spatial dimensions
UNCOV
224
    std::vector<symbolic::Expression> output_spatial_dims;
×
UNCOV
225
    for (size_t i = 0; i < spatial_dims; ++i) {
×
226
        // D_out = floor((D_in + pads_begin + pads_end - dilation * (kernel - 1) - 1) / stride) + 1
UNCOV
227
        auto d_in = input_spatial_dims[i];
×
UNCOV
228
        auto pad = symbolic::add(pads_begin_vec[i], pads_end_vec[i]);
×
UNCOV
229
        auto dk = symbolic::mul(dilations_vec[i], symbolic::sub(kernel_shape_[i], symbolic::one()));
×
UNCOV
230
        auto num = symbolic::sub(symbolic::add(d_in, pad), symbolic::add(dk, symbolic::one()));
×
UNCOV
231
        auto d_out = symbolic::add(symbolic::div(num, strides_vec[i]), symbolic::one());
×
UNCOV
232
        output_spatial_dims.push_back(d_out);
×
UNCOV
233
    }
×
234

235
    // Create new sequence for expansion
UNCOV
236
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
×
237

238
    // Create nested map structure for convolution
UNCOV
239
    structured_control_flow::Sequence* current_scope = &new_sequence;
×
UNCOV
240
    std::vector<symbolic::Expression> output_indices;
×
UNCOV
241
    std::vector<symbolic::Expression> output_spatial_vars;
×
242

243
    // Map over batch dimension
UNCOV
244
    std::string n_str = builder.find_new_name("n");
×
UNCOV
245
    builder.add_container(n_str, types::Scalar(types::PrimitiveType::UInt64));
×
UNCOV
246
    auto n_var = symbolic::symbol(n_str);
×
UNCOV
247
    auto& map_n = builder.add_map(
×
UNCOV
248
        *current_scope,
×
UNCOV
249
        n_var,
×
UNCOV
250
        symbolic::Lt(n_var, N),
×
UNCOV
251
        symbolic::zero(),
×
UNCOV
252
        symbolic::add(n_var, symbolic::one()),
×
UNCOV
253
        structured_control_flow::ScheduleType_Sequential::create(),
×
UNCOV
254
        {},
×
UNCOV
255
        block.debug_info()
×
UNCOV
256
    );
×
UNCOV
257
    current_scope = &map_n.root();
×
UNCOV
258
    output_indices.push_back(n_var);
×
259

260
    // Map over output channel dimension
UNCOV
261
    std::string oc_str = builder.find_new_name("oc");
×
UNCOV
262
    builder.add_container(oc_str, types::Scalar(types::PrimitiveType::UInt64));
×
UNCOV
263
    auto oc_var = symbolic::symbol(oc_str);
×
UNCOV
264
    auto& map_oc = builder.add_map(
×
UNCOV
265
        *current_scope,
×
UNCOV
266
        oc_var,
×
UNCOV
267
        symbolic::Lt(oc_var, C_out),
×
UNCOV
268
        symbolic::zero(),
×
UNCOV
269
        symbolic::add(oc_var, symbolic::one()),
×
UNCOV
270
        structured_control_flow::ScheduleType_Sequential::create(),
×
UNCOV
271
        {},
×
UNCOV
272
        block.debug_info()
×
UNCOV
273
    );
×
UNCOV
274
    current_scope = &map_oc.root();
×
UNCOV
275
    output_indices.push_back(oc_var);
×
276

277
    // Map over each output spatial dimension dynamically
UNCOV
278
    for (size_t i = 0; i < spatial_dims; ++i) {
×
UNCOV
279
        std::string od_str = builder.find_new_name("od" + std::to_string(i));
×
UNCOV
280
        builder.add_container(od_str, types::Scalar(types::PrimitiveType::UInt64));
×
UNCOV
281
        auto od_var = symbolic::symbol(od_str);
×
UNCOV
282
        auto& map_od = builder.add_map(
×
UNCOV
283
            *current_scope,
×
UNCOV
284
            od_var,
×
UNCOV
285
            symbolic::Lt(od_var, output_spatial_dims[i]),
×
UNCOV
286
            symbolic::zero(),
×
UNCOV
287
            symbolic::add(od_var, symbolic::one()),
×
UNCOV
288
            structured_control_flow::ScheduleType_Sequential::create(),
×
UNCOV
289
            {},
×
UNCOV
290
            block.debug_info()
×
UNCOV
291
        );
×
UNCOV
292
        current_scope = &map_od.root();
×
UNCOV
293
        output_indices.push_back(od_var);
×
UNCOV
294
        output_spatial_vars.push_back(od_var);
×
UNCOV
295
    }
×
296

297
    // Create accumulator variable for the sum
UNCOV
298
    std::string accum_var = builder.find_new_name("_conv_accum");
×
UNCOV
299
    builder.add_container(accum_var, scalar_type);
×
300

301
    // Initialize accumulator to 0
UNCOV
302
    auto& init_block = builder.add_block(*current_scope, {}, block.debug_info());
×
UNCOV
303
    auto& accum_init = builder.add_access(init_block, accum_var, block.debug_info());
×
UNCOV
304
    auto& zero_const = builder.add_constant(init_block, "0.0", scalar_type, block.debug_info());
×
UNCOV
305
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
×
UNCOV
306
    builder.add_computational_memlet(init_block, zero_const, init_tasklet, "_in", {}, scalar_type, block.debug_info());
×
UNCOV
307
    builder.add_computational_memlet(init_block, init_tasklet, "_out", accum_init, {}, scalar_type, block.debug_info());
×
308

309
    // Create nested for loops for input channels and kernel dimensions
310
    // For loop over input channels
UNCOV
311
    std::string ic_str = builder.find_new_name("ic");
×
UNCOV
312
    builder.add_container(ic_str, types::Scalar(types::PrimitiveType::UInt64));
×
UNCOV
313
    auto ic_var = symbolic::symbol(ic_str);
×
UNCOV
314
    auto& for_ic = builder.add_for(
×
UNCOV
315
        *current_scope,
×
UNCOV
316
        ic_var,
×
UNCOV
317
        symbolic::Lt(ic_var, C_in),
×
UNCOV
318
        symbolic::zero(),
×
UNCOV
319
        symbolic::add(ic_var, symbolic::one()),
×
UNCOV
320
        {},
×
UNCOV
321
        block.debug_info()
×
UNCOV
322
    );
×
UNCOV
323
    auto* loop_scope = &for_ic.root();
×
324

325
    // For loops over each kernel spatial dimension
UNCOV
326
    std::vector<symbolic::Expression> kernel_vars;
×
UNCOV
327
    for (size_t i = 0; i < spatial_dims; ++i) {
×
UNCOV
328
        std::string k_str = builder.find_new_name("k" + std::to_string(i));
×
UNCOV
329
        builder.add_container(k_str, types::Scalar(types::PrimitiveType::UInt64));
×
UNCOV
330
        auto k_var = symbolic::symbol(k_str);
×
UNCOV
331
        auto& for_k = builder.add_for(
×
UNCOV
332
            *loop_scope,
×
UNCOV
333
            k_var,
×
UNCOV
334
            symbolic::Lt(k_var, kernel_shape_[i]),
×
UNCOV
335
            symbolic::zero(),
×
UNCOV
336
            symbolic::add(k_var, symbolic::one()),
×
UNCOV
337
            {},
×
UNCOV
338
            block.debug_info()
×
UNCOV
339
        );
×
UNCOV
340
        loop_scope = &for_k.root();
×
UNCOV
341
        kernel_vars.push_back(k_var);
×
UNCOV
342
    }
×
343

344
    // Compute indices for input and weight access
345
    // Input index: [n, ic, od0 * stride0 - pad0 + k0 * dilation0, ...]
346
    // Note: taking dilation into account for input index calculation
UNCOV
347
    std::vector<symbolic::Expression> input_spatial_indices;
×
UNCOV
348
    for (size_t i = 0; i < spatial_dims; ++i) {
×
UNCOV
349
        auto k_dilated = symbolic::mul(kernel_vars[i], dilations_vec[i]);
×
UNCOV
350
        auto input_idx = symbolic::
×
UNCOV
351
            add(symbolic::sub(symbolic::mul(output_spatial_vars[i], strides_vec[i]), pads_begin_vec[i]), k_dilated);
×
UNCOV
352
        input_spatial_indices.push_back(input_idx);
×
UNCOV
353
    }
×
354

355
    // Create computation block
UNCOV
356
    auto& comp_block = builder.add_block(*loop_scope, {}, block.debug_info());
×
357

358
    // Access input X[n, ic, input_spatial_indices...]
UNCOV
359
    auto& x_access = builder.add_access(comp_block, X_var, x_node->debug_info());
×
360
    // Access weight W[oc, ic, k0, k1, ...]
UNCOV
361
    auto& w_access = builder.add_access(comp_block, W_var, w_node->debug_info());
×
362
    // Access accumulator
UNCOV
363
    auto& accum_read = builder.add_access(comp_block, accum_var, block.debug_info());
×
UNCOV
364
    auto& accum_write = builder.add_access(comp_block, accum_var, block.debug_info());
×
365

366
    // Create FMA tasklet: accum = accum + x * w
UNCOV
367
    auto& fma_tasklet =
×
UNCOV
368
        builder.add_tasklet(comp_block, data_flow::fp_fma, "_out", {"_in1", "_in2", "_in3"}, block.debug_info());
×
369

370

371
    // Calculate shapes for
372
    // X shape: [N, C_in, D0, D1...]
UNCOV
373
    std::vector<symbolic::Expression> x_shape_vec = {N, C_in};
×
UNCOV
374
    x_shape_vec.insert(x_shape_vec.end(), input_spatial_dims.begin(), input_spatial_dims.end());
×
375

376
    // W shape: [C_out, C_in/group, k0, k1...]
UNCOV
377
    std::vector<symbolic::Expression> w_shape_vec = {C_out, symbolic::div(C_in, group_)};
×
UNCOV
378
    w_shape_vec.insert(w_shape_vec.end(), kernel_shape_.begin(), kernel_shape_.end());
×
379

380
    // Connect edges with subsets
UNCOV
381
    std::vector<symbolic::Expression> x_indices_vec = {n_var, ic_var};
×
UNCOV
382
    x_indices_vec.insert(x_indices_vec.end(), input_spatial_indices.begin(), input_spatial_indices.end());
×
383

UNCOV
384
    std::vector<symbolic::Expression> w_indices_vec = {oc_var, ic_var}; // Assuming group=1 for now for simplicity of
×
385
                                                                        // indices
386
    // TODO: Handle groups properly in indices if needed, but for standard conv:
UNCOV
387
    w_indices_vec.insert(w_indices_vec.end(), kernel_vars.begin(), kernel_vars.end());
×
388

NEW
389
    data_flow::Subset x_subset(x_indices_vec);
×
NEW
390
    data_flow::Subset w_subset(w_indices_vec);
×
391

UNCOV
392
    builder.add_computational_memlet(
×
UNCOV
393
        comp_block, x_access, fma_tasklet, "_in1", x_subset, x_edge->base_type(), x_edge->debug_info()
×
UNCOV
394
    );
×
UNCOV
395
    builder.add_computational_memlet(
×
UNCOV
396
        comp_block, w_access, fma_tasklet, "_in2", w_subset, w_edge->base_type(), w_edge->debug_info()
×
UNCOV
397
    );
×
UNCOV
398
    builder.add_computational_memlet(comp_block, accum_read, fma_tasklet, "_in3", {}, scalar_type, block.debug_info());
×
UNCOV
399
    builder.add_computational_memlet(comp_block, fma_tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());
×
400

401
    // After all loops, write accumulated result to output (with optional bias)
UNCOV
402
    auto& output_block = builder.add_block(*current_scope, {}, block.debug_info());
×
UNCOV
403
    auto& accum_final = builder.add_access(output_block, accum_var, block.debug_info());
×
UNCOV
404
    auto& y_access = builder.add_access(output_block, Y_var, y_node->debug_info());
×
405

406
    // Y shape: [N, C_out, D0_out, ...]
UNCOV
407
    std::vector<symbolic::Expression> y_shape_vec = {N, C_out};
×
UNCOV
408
    y_shape_vec.insert(y_shape_vec.end(), output_spatial_dims.begin(), output_spatial_dims.end());
×
409

NEW
410
    data_flow::Subset y_subset(output_indices);
×
411

UNCOV
412
    if (b_node) {
×
413
        // Add bias: output = accum + bias[oc]
414
        auto& b_access = builder.add_access(output_block, b_node->data(), b_node->debug_info());
×
415
        auto& add_tasklet =
×
416
            builder.add_tasklet(output_block, data_flow::fp_add, "_out", {"_in1", "_in2"}, block.debug_info());
×
417

418
        builder
×
419
            .add_computational_memlet(output_block, accum_final, add_tasklet, "_in1", {}, scalar_type, block.debug_info());
×
420
        builder.add_computational_memlet(
×
421
            output_block, b_access, add_tasklet, "_in2", {oc_var}, b_edge->base_type(), b_edge->debug_info()
×
422
        );
×
423
        builder.add_computational_memlet(
×
424
            output_block, add_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
×
425
        );
×
UNCOV
426
    } else {
×
427
        // No bias: output = accum
UNCOV
428
        auto& assign_tasklet =
×
UNCOV
429
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
×
430

UNCOV
431
        builder.add_computational_memlet(
×
UNCOV
432
            output_block, accum_final, assign_tasklet, "_in", {}, scalar_type, block.debug_info()
×
UNCOV
433
        );
×
UNCOV
434
        builder.add_computational_memlet(
×
UNCOV
435
            output_block, assign_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
×
UNCOV
436
        );
×
UNCOV
437
    }
×
438

439
    // Clean up the original block
UNCOV
440
    builder.remove_memlet(block, *x_edge);
×
UNCOV
441
    builder.remove_memlet(block, *w_edge);
×
UNCOV
442
    if (b_edge) {
×
443
        builder.remove_memlet(block, *b_edge);
×
444
        builder.remove_node(block, *b_node);
×
445
    }
×
UNCOV
446
    builder.remove_memlet(block, y_edge);
×
UNCOV
447
    builder.remove_node(block, *x_node);
×
UNCOV
448
    builder.remove_node(block, *w_node);
×
UNCOV
449
    builder.remove_node(block, *y_node);
×
UNCOV
450
    builder.remove_node(block, *this);
×
UNCOV
451
    builder.remove_child(parent, index + 1);
×
452

UNCOV
453
    return true;
×
UNCOV
454
}
×
455

456
/// Collects every symbolic atom referenced by this node's attributes
/// (shape, kernel_shape, strides, pads, dilations, output_channels, group).
///
/// @return the set of all symbols appearing in the node's expressions
symbolic::SymbolSet ConvNode::symbols() const {
    symbolic::SymbolSet result;

    // Fold the atoms of a single expression into the result set.
    auto collect = [&result](const symbolic::Expression& expr) {
        for (auto& atom : symbolic::atoms(expr)) {
            result.insert(atom);
        }
    };

    for (auto& expr : shape_) {
        collect(expr);
    }
    for (auto& expr : kernel_shape_) {
        collect(expr);
    }
    for (auto& expr : strides_) {
        collect(expr);
    }
    for (auto& expr : pads_) {
        collect(expr);
    }
    for (auto& expr : dilations_) {
        collect(expr);
    }
    collect(output_channels_);
    collect(group_);

    return result;
}
493

494
/// Substitutes old_expression with new_expression in every symbolic
/// attribute of the node (shapes, strides, pads, dilations, channels, group).
///
/// @param old_expression  expression to replace
/// @param new_expression  replacement expression
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Apply the substitution to each element of an expression vector in place.
    auto subs_all = [&](std::vector<symbolic::Expression>& exprs) {
        for (auto& expr : exprs) {
            expr = symbolic::subs(expr, old_expression, new_expression);
        }
    };

    subs_all(shape_);
    subs_all(kernel_shape_);
    subs_all(strides_);
    subs_all(pads_);
    subs_all(dilations_);

    output_channels_ = symbolic::subs(output_channels_, old_expression, new_expression);
    group_ = symbolic::subs(group_, old_expression, new_expression);
}
513

514
/// Creates a deep copy of this node for insertion into another dataflow
/// graph, preserving all convolution attributes but taking the new
/// element id, vertex, and parent graph.
///
/// @param element_id  element id for the cloned node
/// @param vertex      graph vertex the clone will occupy
/// @param parent      dataflow graph owning the clone
/// @return owning pointer to the freshly constructed ConvNode
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
        element_id,
        this->debug_info(),
        vertex,
        parent,
        shape_,
        kernel_shape_,
        strides_,
        pads_,
        dilations_,
        output_channels_,
        group_
    ));
}
530

531
std::string ConvNode::toStr() const {
×
532
    std::string result = "Conv(shape=[";
×
533
    for (size_t i = 0; i < shape_.size(); ++i) {
×
534
        if (i > 0) result += ", ";
×
535
        result += shape_[i]->__str__();
×
536
    }
×
537
    result += "], kernel_shape=[";
×
538
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
539
        if (i > 0) result += ", ";
×
540
        result += kernel_shape_[i]->__str__();
×
541
    }
×
542
    result += "], strides=[";
×
543
    for (size_t i = 0; i < strides_.size(); ++i) {
×
544
        if (i > 0) result += ", ";
×
545
        result += strides_[i]->__str__();
×
546
    }
×
547
    result += "], output_channels=" + output_channels_->__str__();
×
548
    result += ", group=" + group_->__str__() + ")";
×
549
    return result;
×
550
}
×
551

552
/// Serializes a ConvNode into JSON: the node code plus all symbolic
/// attributes (shape, kernel_shape, strides, pads, dilations,
/// output_channels, group) rendered via the expression serializer.
///
/// @param library_node  node to serialize; must actually be a ConvNode
/// @return JSON object with the node's attributes
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& conv_node = static_cast<const ConvNode&>(library_node);

    nlohmann::json j;
    j["code"] = conv_node.code().value();

    serializer::JSONSerializer serializer;

    // Render a vector of symbolic expressions as a JSON array of strings.
    auto expr_array = [&serializer](const std::vector<symbolic::Expression>& exprs) {
        nlohmann::json arr = nlohmann::json::array();
        for (auto& expr : exprs) {
            arr.push_back(serializer.expression(expr));
        }
        return arr;
    };

    j["shape"] = expr_array(conv_node.shape());
    j["kernel_shape"] = expr_array(conv_node.kernel_shape());
    j["strides"] = expr_array(conv_node.strides());
    j["pads"] = expr_array(conv_node.pads());
    j["dilations"] = expr_array(conv_node.dilations());

    j["output_channels"] = serializer.expression(conv_node.output_channels());
    j["group"] = serializer.expression(conv_node.group());

    return j;
}
590

591
/// Deserializes a ConvNode from JSON and adds it to the given block.
///
/// Required keys: element_id, code, debug_info, kernel_shape.
/// Optional keys: shape, strides, pads, dilations (default: empty vectors),
/// output_channels and group (default: symbolic one).
///
/// @param j        JSON produced by serialize()
/// @param builder  builder used to create the library node
/// @param parent   block to add the node to
/// @return reference to the newly created library node
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parse an optional JSON array of expression strings; absent key → empty.
    auto parse_optional_exprs = [&j](const char* key) {
        std::vector<symbolic::Expression> exprs;
        if (j.contains(key)) {
            for (const auto& entry : j[key]) {
                exprs.push_back(symbolic::parse(entry.get<std::string>()));
            }
        }
        return exprs;
    };

    auto shape = parse_optional_exprs("shape");

    // kernel_shape is mandatory (asserted above) and parsed unconditionally.
    std::vector<symbolic::Expression> kernel_shape;
    for (const auto& dim : j["kernel_shape"]) {
        kernel_shape.push_back(symbolic::parse(dim.get<std::string>()));
    }

    auto strides = parse_optional_exprs("strides");
    auto pads = parse_optional_exprs("pads");
    auto dilations = parse_optional_exprs("dilations");

    // Scalar attributes default to one when absent.
    symbolic::Expression output_channels = symbolic::one();
    if (j.contains("output_channels")) {
        output_channels = symbolic::parse(j["output_channels"].get<std::string>());
    }

    symbolic::Expression group = symbolic::one();
    if (j.contains("group")) {
        group = symbolic::parse(j["group"].get<std::string>());
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
648

649
} // namespace tensor
650
} // namespace math
651
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc