• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20896124135

11 Jan 2026 01:43PM UTC coverage: 62.33% (-0.07%) from 62.402%
20896124135

push

github

web-flow
Merge pull request #423 from daisytuner/copilot/extend-tensor-nodes-conv

Add ConvNode compatible with ONNX Conv operator with n-dimensional expansion and custom validation

329 of 554 new or added lines in 21 files covered. (59.39%)

2 existing lines in 2 files now uncovered.

15413 of 24728 relevant lines covered (62.33%)

88.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

59.14
/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4

5
#include "sdfg/analysis/analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7
#include "sdfg/types/type.h"
8

9
#include "sdfg/analysis/scope_analysis.h"
10
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
11

12
namespace sdfg {
13
namespace math {
14
namespace tensor {
15

16
ConvNode::ConvNode(
17
    size_t element_id,
18
    const DebugInfo& debug_info,
19
    const graph::Vertex vertex,
20
    data_flow::DataFlowGraph& parent,
21
    const std::vector<symbolic::Expression>& shape,
22
    const std::vector<symbolic::Expression>& kernel_shape,
23
    const std::vector<symbolic::Expression>& strides,
24
    const std::vector<symbolic::Expression>& pads,
25
    const std::vector<symbolic::Expression>& dilations,
26
    symbolic::Expression output_channels,
27
    symbolic::Expression group
28
)
29
    : TensorNode(
22✔
30
          element_id,
22✔
31
          debug_info,
22✔
32
          vertex,
22✔
33
          parent,
22✔
34
          LibraryNodeType_Conv,
22✔
35
          {"Y"},
22✔
36
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
22✔
37
          data_flow::ImplementationType_NONE
22✔
38
      ),
22✔
39
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
22✔
40
      output_channels_(output_channels), group_(group) {}
22✔
41

42
void ConvNode::validate(const Function& function) const {
17✔
43
    TensorNode::validate(function);
17✔
44

45
    auto& graph = this->get_parent();
17✔
46

47
    // Custom validation for ConvNode that handles optional bias input
48
    // We expect X, W as required inputs and optionally B (bias)
49

50
    // Collect all input edges by connector name
51
    std::map<std::string, const data_flow::Memlet*> input_edges;
17✔
52
    for (auto& iedge : graph.in_edges(*this)) {
34✔
53
        input_edges[iedge.dst_conn()] = &iedge;
34✔
54
    }
34✔
55

56
    // Check that required inputs X and W are present
57
    if (input_edges.find("X") == input_edges.end()) {
17✔
NEW
58
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
NEW
59
    }
×
60
    if (input_edges.find("W") == input_edges.end()) {
17✔
NEW
61
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
NEW
62
    }
×
63

64
    // Validate kernel shape is not empty
65
    if (kernel_shape_.empty()) {
17✔
NEW
66
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
NEW
67
    }
×
68

69
    // Validate strides, pads, dilations have consistent dimensions
70
    size_t spatial_dims = kernel_shape_.size();
17✔
71

72
    if (!strides_.empty() && strides_.size() != spatial_dims) {
17✔
73
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
1✔
74
    }
1✔
75

76
    if (!pads_.empty() && pads_.size() != 2 * spatial_dims) {
16✔
77
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
1✔
78
    }
1✔
79

80
    if (!dilations_.empty() && dilations_.size() != spatial_dims) {
15✔
NEW
81
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
NEW
82
    }
×
83
}
15✔
84

85
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
5✔
86
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
5✔
87

88
    auto& dataflow = this->get_parent();
5✔
89
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
5✔
90
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
5✔
91
    int index = parent.index(block);
5✔
92
    auto& transition = parent.at(index).second;
5✔
93

94
    // Get primitive type
95
    auto primitive_type = this->primitive_type(dataflow);
5✔
96
    types::Scalar scalar_type(primitive_type);
5✔
97

98
    // Get input edges
99
    auto in_edges = dataflow.in_edges(*this);
5✔
100
    auto in_edges_it = in_edges.begin();
5✔
101

102
    data_flow::Memlet* x_edge = nullptr;
5✔
103
    data_flow::Memlet* w_edge = nullptr;
5✔
104
    data_flow::Memlet* b_edge = nullptr;
5✔
105

106
    while (in_edges_it != in_edges.end()) {
15✔
107
        auto& edge = *in_edges_it;
10✔
108
        auto dst_conn = edge.dst_conn();
10✔
109
        if (dst_conn == "X") {
10✔
110
            x_edge = &edge;
5✔
111
        } else if (dst_conn == "W") {
5✔
112
            w_edge = &edge;
5✔
113
        } else if (dst_conn == "B") {
5✔
NEW
114
            b_edge = &edge;
×
NEW
115
        } else {
×
NEW
116
            throw InvalidSDFGException("ConvNode has unexpected input: " + dst_conn);
×
NEW
117
        }
×
118
        ++in_edges_it;
10✔
119
    }
10✔
120

121
    if (!x_edge || !w_edge) {
5✔
NEW
122
        throw InvalidSDFGException("ConvNode requires X and W inputs");
×
NEW
123
    }
×
124

125
    auto& y_edge = *dataflow.out_edges(*this).begin();
5✔
126

127
    // Get access nodes
128
    auto* x_node = static_cast<data_flow::AccessNode*>(&x_edge->src());
5✔
129
    auto* w_node = static_cast<data_flow::AccessNode*>(&w_edge->src());
5✔
130
    data_flow::AccessNode* b_node = b_edge ? static_cast<data_flow::AccessNode*>(&b_edge->src()) : nullptr;
5✔
131
    auto* y_node = static_cast<data_flow::AccessNode*>(&y_edge.dst());
5✔
132

133
    // Validate nodes are standalone in the block
134
    if (!x_node || dataflow.in_degree(*x_node) != 0 || !w_node || dataflow.in_degree(*w_node) != 0 || !y_node ||
5✔
135
        dataflow.out_degree(*y_node) != 0) {
5✔
NEW
136
        return false;
×
NEW
137
    }
×
138

139
    if (b_node && dataflow.in_degree(*b_node) != 0) {
5✔
NEW
140
        return false;
×
NEW
141
    }
×
142

143
    // Check that all other nodes in the block are the expected ones
144
    for (auto* nd : dataflow.data_nodes()) {
15✔
145
        if (nd != x_node && nd != w_node && nd != y_node && (!b_node || nd != b_node)) {
15✔
NEW
146
            return false; // there are other nodes we cannot handle
×
NEW
147
        }
×
148
    }
15✔
149

150
    // Support n-dimensional convolutions
151
    size_t spatial_dims = kernel_shape_.size();
5✔
152

153
    if (spatial_dims == 0) {
5✔
NEW
154
        return false; // Need at least 1 spatial dimension
×
NEW
155
    }
×
156

157
    // Get strides (default to 1 if not provided)
158
    std::vector<symbolic::Expression> strides_vec;
5✔
159
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
160
        if (i < strides_.size()) {
10✔
161
            strides_vec.push_back(strides_[i]);
10✔
162
        } else {
10✔
NEW
163
            strides_vec.push_back(symbolic::one());
×
NEW
164
        }
×
165
    }
10✔
166

167
    // Get padding (default to 0 if not provided)
168
    // Pads format: [begin_0, begin_1, ..., begin_n, end_0, end_1, ..., end_n]
169
    std::vector<symbolic::Expression> pads_begin_vec;
5✔
170
    std::vector<symbolic::Expression> pads_end_vec;
5✔
171
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
172
        if (i < pads_.size()) {
10✔
173
            pads_begin_vec.push_back(pads_[i]);
10✔
174
        } else {
10✔
NEW
175
            pads_begin_vec.push_back(symbolic::zero());
×
NEW
176
        }
×
177

178
        if (spatial_dims + i < pads_.size()) {
10✔
179
            pads_end_vec.push_back(pads_[spatial_dims + i]);
10✔
180
        } else {
10✔
NEW
181
            pads_end_vec.push_back(symbolic::zero());
×
NEW
182
        }
×
183
    }
10✔
184

185
    // Get dilations (default to 1 if not provided)
186
    std::vector<symbolic::Expression> dilations_vec;
5✔
187
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
188
        if (i < dilations_.size()) {
10✔
189
            dilations_vec.push_back(dilations_[i]);
10✔
190
        } else {
10✔
NEW
191
            dilations_vec.push_back(symbolic::one());
×
NEW
192
        }
×
193
    }
10✔
194

195
    // Get variable names
196
    auto& X_var = x_node->data();
5✔
197
    auto& W_var = w_node->data();
5✔
198
    auto& Y_var = y_node->data();
5✔
199

200
    // Use shape_ for dimensions if available
201
    // For a generic n-dimensional implementation:
202
    // Input X shape: [N, C_in, D0_in, D1_in, ..., Dn_in]
203
    symbolic::Expression N, C_in;
5✔
204
    std::vector<symbolic::Expression> input_spatial_dims;
5✔
205

206
    if (shape_.size() >= 2 + spatial_dims) {
5✔
207
        N = shape_[0];
5✔
208
        C_in = shape_[1];
5✔
209
        for (size_t i = 0; i < spatial_dims; ++i) {
15✔
210
            input_spatial_dims.push_back(shape_[2 + i]);
10✔
211
        }
10✔
212
    } else {
5✔
NEW
213
        N = symbolic::symbol(builder.find_new_name("N"));
×
NEW
214
        C_in = symbolic::symbol(builder.find_new_name("C_in"));
×
NEW
215
        for (size_t i = 0; i < spatial_dims; ++i) {
×
NEW
216
            input_spatial_dims.push_back(symbolic::symbol(builder.find_new_name("D" + std::to_string(i) + "_in")));
×
NEW
217
        }
×
NEW
218
    }
×
219

220
    // Output Channel (C_out) is passed via constructor
221
    auto C_out = output_channels_;
5✔
222

223
    // Calculate output spatial dimensions
224
    std::vector<symbolic::Expression> output_spatial_dims;
5✔
225
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
226
        // D_out = floor((D_in + pads_begin + pads_end - dilation * (kernel - 1) - 1) / stride) + 1
227
        auto d_in = input_spatial_dims[i];
10✔
228
        auto pad = symbolic::add(pads_begin_vec[i], pads_end_vec[i]);
10✔
229
        auto dk = symbolic::mul(dilations_vec[i], symbolic::sub(kernel_shape_[i], symbolic::one()));
10✔
230
        auto num = symbolic::sub(symbolic::add(d_in, pad), symbolic::add(dk, symbolic::one()));
10✔
231
        auto d_out = symbolic::add(symbolic::div(num, strides_vec[i]), symbolic::one());
10✔
232
        output_spatial_dims.push_back(d_out);
10✔
233
    }
10✔
234

235
    // Create new sequence for expansion
236
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
5✔
237

238
    // Create nested map structure for convolution
239
    structured_control_flow::Sequence* current_scope = &new_sequence;
5✔
240
    std::vector<symbolic::Expression> output_indices;
5✔
241
    std::vector<symbolic::Expression> output_spatial_vars;
5✔
242

243
    // Map over batch dimension
244
    std::string n_str = builder.find_new_name("n");
5✔
245
    builder.add_container(n_str, types::Scalar(types::PrimitiveType::UInt64));
5✔
246
    auto n_var = symbolic::symbol(n_str);
5✔
247
    auto& map_n = builder.add_map(
5✔
248
        *current_scope,
5✔
249
        n_var,
5✔
250
        symbolic::Lt(n_var, N),
5✔
251
        symbolic::zero(),
5✔
252
        symbolic::add(n_var, symbolic::one()),
5✔
253
        structured_control_flow::ScheduleType_Sequential::create(),
5✔
254
        {},
5✔
255
        block.debug_info()
5✔
256
    );
5✔
257
    current_scope = &map_n.root();
5✔
258
    output_indices.push_back(n_var);
5✔
259

260
    // Map over output channel dimension
261
    std::string oc_str = builder.find_new_name("oc");
5✔
262
    builder.add_container(oc_str, types::Scalar(types::PrimitiveType::UInt64));
5✔
263
    auto oc_var = symbolic::symbol(oc_str);
5✔
264
    auto& map_oc = builder.add_map(
5✔
265
        *current_scope,
5✔
266
        oc_var,
5✔
267
        symbolic::Lt(oc_var, C_out),
5✔
268
        symbolic::zero(),
5✔
269
        symbolic::add(oc_var, symbolic::one()),
5✔
270
        structured_control_flow::ScheduleType_Sequential::create(),
5✔
271
        {},
5✔
272
        block.debug_info()
5✔
273
    );
5✔
274
    current_scope = &map_oc.root();
5✔
275
    output_indices.push_back(oc_var);
5✔
276

277
    // Map over each output spatial dimension dynamically
278
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
279
        std::string od_str = builder.find_new_name("od" + std::to_string(i));
10✔
280
        builder.add_container(od_str, types::Scalar(types::PrimitiveType::UInt64));
10✔
281
        auto od_var = symbolic::symbol(od_str);
10✔
282
        auto& map_od = builder.add_map(
10✔
283
            *current_scope,
10✔
284
            od_var,
10✔
285
            symbolic::Lt(od_var, output_spatial_dims[i]),
10✔
286
            symbolic::zero(),
10✔
287
            symbolic::add(od_var, symbolic::one()),
10✔
288
            structured_control_flow::ScheduleType_Sequential::create(),
10✔
289
            {},
10✔
290
            block.debug_info()
10✔
291
        );
10✔
292
        current_scope = &map_od.root();
10✔
293
        output_indices.push_back(od_var);
10✔
294
        output_spatial_vars.push_back(od_var);
10✔
295
    }
10✔
296

297
    // Create accumulator variable for the sum
298
    std::string accum_var = builder.find_new_name("_conv_accum");
5✔
299
    builder.add_container(accum_var, scalar_type);
5✔
300

301
    // Initialize accumulator to 0
302
    auto& init_block = builder.add_block(*current_scope, {}, block.debug_info());
5✔
303
    auto& accum_init = builder.add_access(init_block, accum_var, block.debug_info());
5✔
304
    auto& zero_const = builder.add_constant(init_block, "0.0", scalar_type, block.debug_info());
5✔
305
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
5✔
306
    builder.add_computational_memlet(init_block, zero_const, init_tasklet, "_in", {}, scalar_type, block.debug_info());
5✔
307
    builder.add_computational_memlet(init_block, init_tasklet, "_out", accum_init, {}, scalar_type, block.debug_info());
5✔
308

309
    // Create nested for loops for input channels and kernel dimensions
310
    // For loop over input channels
311
    std::string ic_str = builder.find_new_name("ic");
5✔
312
    builder.add_container(ic_str, types::Scalar(types::PrimitiveType::UInt64));
5✔
313
    auto ic_var = symbolic::symbol(ic_str);
5✔
314
    auto& for_ic = builder.add_for(
5✔
315
        *current_scope,
5✔
316
        ic_var,
5✔
317
        symbolic::Lt(ic_var, C_in),
5✔
318
        symbolic::zero(),
5✔
319
        symbolic::add(ic_var, symbolic::one()),
5✔
320
        {},
5✔
321
        block.debug_info()
5✔
322
    );
5✔
323
    auto* loop_scope = &for_ic.root();
5✔
324

325
    // For loops over each kernel spatial dimension
326
    std::vector<symbolic::Expression> kernel_vars;
5✔
327
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
328
        std::string k_str = builder.find_new_name("k" + std::to_string(i));
10✔
329
        builder.add_container(k_str, types::Scalar(types::PrimitiveType::UInt64));
10✔
330
        auto k_var = symbolic::symbol(k_str);
10✔
331
        auto& for_k = builder.add_for(
10✔
332
            *loop_scope,
10✔
333
            k_var,
10✔
334
            symbolic::Lt(k_var, kernel_shape_[i]),
10✔
335
            symbolic::zero(),
10✔
336
            symbolic::add(k_var, symbolic::one()),
10✔
337
            {},
10✔
338
            block.debug_info()
10✔
339
        );
10✔
340
        loop_scope = &for_k.root();
10✔
341
        kernel_vars.push_back(k_var);
10✔
342
    }
10✔
343

344
    // Compute indices for input and weight access
345
    // Input index: [n, ic, od0 * stride0 - pad0 + k0 * dilation0, ...]
346
    // Note: taking dilation into account for input index calculation
347
    std::vector<symbolic::Expression> input_spatial_indices;
5✔
348
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
349
        auto k_dilated = symbolic::mul(kernel_vars[i], dilations_vec[i]);
10✔
350
        auto input_idx = symbolic::
10✔
351
            add(symbolic::sub(symbolic::mul(output_spatial_vars[i], strides_vec[i]), pads_begin_vec[i]), k_dilated);
10✔
352
        input_spatial_indices.push_back(input_idx);
10✔
353
    }
10✔
354

355
    // Create computation block
356
    auto& comp_block = builder.add_block(*loop_scope, {}, block.debug_info());
5✔
357

358
    // Access input X[n, ic, input_spatial_indices...]
359
    auto& x_access = builder.add_access(comp_block, X_var, x_node->debug_info());
5✔
360
    // Access weight W[oc, ic, k0, k1, ...]
361
    auto& w_access = builder.add_access(comp_block, W_var, w_node->debug_info());
5✔
362
    // Access accumulator
363
    auto& accum_read = builder.add_access(comp_block, accum_var, block.debug_info());
5✔
364
    auto& accum_write = builder.add_access(comp_block, accum_var, block.debug_info());
5✔
365

366
    // Create FMA tasklet: accum = accum + x * w
367
    auto& fma_tasklet =
5✔
368
        builder.add_tasklet(comp_block, data_flow::fp_fma, "_out", {"_in1", "_in2", "_in3"}, block.debug_info());
5✔
369

370
    // Linearization helper
371
    auto linearize = [&](const std::vector<symbolic::Expression>& indices,
5✔
372
                         const std::vector<symbolic::Expression>& shape) -> symbolic::Expression {
15✔
373
        symbolic::Expression idx = symbolic::zero();
15✔
374
        symbolic::Expression stride = symbolic::one();
15✔
375
        for (int i = shape.size() - 1; i >= 0; --i) {
75✔
376
            idx = symbolic::add(idx, symbolic::mul(indices[i], stride));
60✔
377
            stride = symbolic::mul(stride, shape[i]);
60✔
378
        }
60✔
379
        return idx;
15✔
380
    };
15✔
381

382
    // Calculate shapes for linearization
383
    // X shape: [N, C_in, D0, D1...]
384
    std::vector<symbolic::Expression> x_shape_vec = {N, C_in};
5✔
385
    x_shape_vec.insert(x_shape_vec.end(), input_spatial_dims.begin(), input_spatial_dims.end());
5✔
386

387
    // W shape: [C_out, C_in/group, k0, k1...]
388
    std::vector<symbolic::Expression> w_shape_vec = {C_out, symbolic::div(C_in, group_)};
5✔
389
    w_shape_vec.insert(w_shape_vec.end(), kernel_shape_.begin(), kernel_shape_.end());
5✔
390

391
    // Connect edges with linearized subsets
392
    std::vector<symbolic::Expression> x_indices_vec = {n_var, ic_var};
5✔
393
    x_indices_vec.insert(x_indices_vec.end(), input_spatial_indices.begin(), input_spatial_indices.end());
5✔
394

395
    std::vector<symbolic::Expression> w_indices_vec = {oc_var, ic_var}; // Assuming group=1 for now for simplicity of
5✔
396
                                                                        // indices
397
    // TODO: Handle groups properly in indices if needed, but for standard conv:
398
    w_indices_vec.insert(w_indices_vec.end(), kernel_vars.begin(), kernel_vars.end());
5✔
399

400
    data_flow::Subset x_subset({linearize(x_indices_vec, x_shape_vec)});
5✔
401
    data_flow::Subset w_subset({linearize(w_indices_vec, w_shape_vec)});
5✔
402

403
    builder.add_computational_memlet(
5✔
404
        comp_block, x_access, fma_tasklet, "_in1", x_subset, x_edge->base_type(), x_edge->debug_info()
5✔
405
    );
5✔
406
    builder.add_computational_memlet(
5✔
407
        comp_block, w_access, fma_tasklet, "_in2", w_subset, w_edge->base_type(), w_edge->debug_info()
5✔
408
    );
5✔
409
    builder.add_computational_memlet(comp_block, accum_read, fma_tasklet, "_in3", {}, scalar_type, block.debug_info());
5✔
410
    builder.add_computational_memlet(comp_block, fma_tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());
5✔
411

412
    // After all loops, write accumulated result to output (with optional bias)
413
    auto& output_block = builder.add_block(*current_scope, {}, block.debug_info());
5✔
414
    auto& accum_final = builder.add_access(output_block, accum_var, block.debug_info());
5✔
415
    auto& y_access = builder.add_access(output_block, Y_var, y_node->debug_info());
5✔
416

417
    // Y shape: [N, C_out, D0_out, ...]
418
    std::vector<symbolic::Expression> y_shape_vec = {N, C_out};
5✔
419
    y_shape_vec.insert(y_shape_vec.end(), output_spatial_dims.begin(), output_spatial_dims.end());
5✔
420

421
    data_flow::Subset y_subset({linearize(output_indices, y_shape_vec)});
5✔
422

423
    if (b_node) {
5✔
424
        // Add bias: output = accum + bias[oc]
NEW
425
        auto& b_access = builder.add_access(output_block, b_node->data(), b_node->debug_info());
×
NEW
426
        auto& add_tasklet =
×
NEW
427
            builder.add_tasklet(output_block, data_flow::fp_add, "_out", {"_in1", "_in2"}, block.debug_info());
×
428

NEW
429
        builder
×
NEW
430
            .add_computational_memlet(output_block, accum_final, add_tasklet, "_in1", {}, scalar_type, block.debug_info());
×
NEW
431
        builder.add_computational_memlet(
×
NEW
432
            output_block, b_access, add_tasklet, "_in2", {oc_var}, b_edge->base_type(), b_edge->debug_info()
×
NEW
433
        );
×
NEW
434
        builder.add_computational_memlet(
×
NEW
435
            output_block, add_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
×
NEW
436
        );
×
437
    } else {
5✔
438
        // No bias: output = accum
439
        auto& assign_tasklet =
5✔
440
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
5✔
441

442
        builder.add_computational_memlet(
5✔
443
            output_block, accum_final, assign_tasklet, "_in", {}, scalar_type, block.debug_info()
5✔
444
        );
5✔
445
        builder.add_computational_memlet(
5✔
446
            output_block, assign_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
5✔
447
        );
5✔
448
    }
5✔
449

450
    // Clean up the original block
451
    builder.remove_memlet(block, *x_edge);
5✔
452
    builder.remove_memlet(block, *w_edge);
5✔
453
    if (b_edge) {
5✔
NEW
454
        builder.remove_memlet(block, *b_edge);
×
NEW
455
        builder.remove_node(block, *b_node);
×
NEW
456
    }
×
457
    builder.remove_memlet(block, y_edge);
5✔
458
    builder.remove_node(block, *x_node);
5✔
459
    builder.remove_node(block, *w_node);
5✔
460
    builder.remove_node(block, *y_node);
5✔
461
    builder.remove_node(block, *this);
5✔
462
    builder.remove_child(parent, index + 1);
5✔
463

464
    return true;
5✔
465
}
5✔
466

NEW
467
/// Returns every symbol appearing in the node's symbolic attributes
/// (shape, kernel_shape, strides, pads, dilations, output_channels and group).
symbolic::SymbolSet ConvNode::symbols() const {
    symbolic::SymbolSet result;

    // Adds all atoms of one expression to the result set.
    auto collect = [&result](const symbolic::Expression& expression) {
        for (auto& atom : symbolic::atoms(expression)) {
            result.insert(atom);
        }
    };

    for (const auto* attribute : {&shape_, &kernel_shape_, &strides_, &pads_, &dilations_}) {
        for (const auto& expression : *attribute) {
            collect(expression);
        }
    }
    collect(output_channels_);
    collect(group_);

    return result;
}
×
504

NEW
505
/// Substitutes old_expression with new_expression in every symbolic attribute of the node
/// (shape, kernel_shape, strides, pads, dilations, output_channels and group).
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Rewrites one expression in place.
    auto substitute = [&old_expression, &new_expression](symbolic::Expression& target) {
        target = symbolic::subs(target, old_expression, new_expression);
    };

    for (auto* attribute : {&shape_, &kernel_shape_, &strides_, &pads_, &dilations_}) {
        for (auto& expression : *attribute) {
            substitute(expression);
        }
    }
    substitute(output_channels_);
    substitute(group_);
}
×
524

525
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
526
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
1✔
527
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
1✔
528
        element_id,
1✔
529
        this->debug_info(),
1✔
530
        vertex,
1✔
531
        parent,
1✔
532
        shape_,
1✔
533
        kernel_shape_,
1✔
534
        strides_,
1✔
535
        pads_,
1✔
536
        dilations_,
1✔
537
        output_channels_,
1✔
538
        group_
1✔
539
    ));
1✔
540
}
1✔
541

NEW
542
std::string ConvNode::toStr() const {
×
NEW
543
    std::string result = "Conv(shape=[";
×
NEW
544
    for (size_t i = 0; i < shape_.size(); ++i) {
×
NEW
545
        if (i > 0) result += ", ";
×
NEW
546
        result += shape_[i]->__str__();
×
NEW
547
    }
×
NEW
548
    result += "], kernel_shape=[";
×
NEW
549
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
NEW
550
        if (i > 0) result += ", ";
×
NEW
551
        result += kernel_shape_[i]->__str__();
×
NEW
552
    }
×
NEW
553
    result += "], strides=[";
×
NEW
554
    for (size_t i = 0; i < strides_.size(); ++i) {
×
NEW
555
        if (i > 0) result += ", ";
×
NEW
556
        result += strides_[i]->__str__();
×
NEW
557
    }
×
NEW
558
    result += "], output_channels=" + output_channels_->__str__();
×
NEW
559
    result += ", group=" + group_->__str__() + ")";
×
NEW
560
    return result;
×
NEW
561
}
×
562

NEW
563
/// Serializes a ConvNode to JSON.
///
/// Emits "code", the expression arrays "shape", "kernel_shape", "strides", "pads" and
/// "dilations", and the scalar expressions "output_channels" and "group". Every expression
/// is converted with JSONSerializer::expression.
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const ConvNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    // Converts a list of expressions to a JSON array.
    auto to_json_array = [&expr_serializer](const std::vector<symbolic::Expression>& values) {
        nlohmann::json array = nlohmann::json::array();
        for (const auto& value : values) {
            array.push_back(expr_serializer.expression(value));
        }
        return array;
    };

    nlohmann::json j;
    j["code"] = node.code().value();
    j["shape"] = to_json_array(node.shape());
    j["kernel_shape"] = to_json_array(node.kernel_shape());
    j["strides"] = to_json_array(node.strides());
    j["pads"] = to_json_array(node.pads());
    j["dilations"] = to_json_array(node.dilations());
    j["output_channels"] = expr_serializer.expression(node.output_channels());
    j["group"] = expr_serializer.expression(node.group());
    return j;
}
×
601

602
/// Reconstructs a ConvNode from JSON and adds it to `parent` via the builder.
///
/// Required keys: "kernel_shape" and "debug_info" (plus "element_id" and "code", which are
/// only asserted here). Optional keys default as follows:
/// "shape", "strides", "pads", "dilations" -> empty; "output_channels", "group" -> 1.
///
/// @throws nlohmann::json::out_of_range if "kernel_shape" or "debug_info" is missing. The
///         asserts are compiled out under NDEBUG, and const operator[] on a missing key is
///         undefined behaviour, so j.at() is used for required keys.
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parses an optional array of expression strings; a missing key yields an empty vector.
    auto parse_optional_list = [&j](const char* key) {
        std::vector<symbolic::Expression> values;
        auto it = j.find(key);
        if (it != j.end()) {
            for (const auto& entry : *it) {
                values.push_back(symbolic::parse(entry.get<std::string>()));
            }
        }
        return values;
    };

    // Parses an optional scalar expression string; a missing key yields 1.
    auto parse_optional_scalar = [&j](const char* key) -> symbolic::Expression {
        auto it = j.find(key);
        if (it == j.end()) {
            return symbolic::one();
        }
        return symbolic::parse(it->get<std::string>());
    };

    std::vector<symbolic::Expression> shape = parse_optional_list("shape");

    std::vector<symbolic::Expression> kernel_shape;
    for (const auto& dim : j.at("kernel_shape")) {
        kernel_shape.push_back(symbolic::parse(dim.get<std::string>()));
    }

    std::vector<symbolic::Expression> strides = parse_optional_list("strides");
    std::vector<symbolic::Expression> pads = parse_optional_list("pads");
    std::vector<symbolic::Expression> dilations = parse_optional_list("dilations");

    symbolic::Expression output_channels = parse_optional_scalar("output_channels");
    symbolic::Expression group = parse_optional_scalar("group");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
×
659

660
} // namespace tensor
661
} // namespace math
662
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc