daisytuner / sdfglib — build 20895298121

11 Jan 2026 12:41PM UTC — coverage: 62.345% (-0.06%) from 62.402%

Pull Request #423 (github / web-flow): Add ConvNode compatible with ONNX Conv operator with n-dimensional expansion and custom validation
Merge 0e0a7fb17 into 7b68d7fab

319 of 532 new or added lines in 21 files covered (59.96%).
2 existing lines in 2 files are now uncovered.
15403 of 24706 relevant lines covered (62.35%).
88.69 hits per line.
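(Arithmetic check: 319 / 532 ≈ 59.96% of the new lines and 15403 / 24706 ≈ 62.35% of all relevant lines, matching the percentages above.)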

Source File: /src/data_flow/library_nodes/math/tensor/conv_node.cpp — 59.66% covered
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"

#include <map>

#include "sdfg/analysis/analysis.h"
#include "sdfg/builder/structured_sdfg_builder.h"
#include "sdfg/types/type.h"

#include "sdfg/analysis/scope_analysis.h"
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"

namespace sdfg {
namespace math {
namespace tensor {

ConvNode::ConvNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<symbolic::Expression>& kernel_shape,
    const std::vector<symbolic::Expression>& strides,
    const std::vector<symbolic::Expression>& pads,
    const std::vector<symbolic::Expression>& dilations,
    symbolic::Expression group
)
    : TensorNode(
          element_id,
          debug_info,
          vertex,
          parent,
          LibraryNodeType_Conv,
          {"Y"},
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
          data_flow::ImplementationType_NONE
      ),
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations), group_(group) {
}

void ConvNode::validate(const Function& function) const {
    TensorNode::validate(function);

    auto& graph = this->get_parent();

    // Custom validation for ConvNode that handles optional bias input
    // We expect X, W as required inputs and optionally B (bias)

    // Collect all input edges by connector name
    std::map<std::string, const data_flow::Memlet*> input_edges;
    for (auto& iedge : graph.in_edges(*this)) {
        input_edges[iedge.dst_conn()] = &iedge;
    }

    // Check that required inputs X and W are present
    if (input_edges.find("X") == input_edges.end()) {
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
    }
    if (input_edges.find("W") == input_edges.end()) {
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
    }

    // Validate kernel shape is not empty
    if (kernel_shape_.empty()) {
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
    }

    // Validate strides, pads, dilations have consistent dimensions
    size_t spatial_dims = kernel_shape_.size();

    if (!strides_.empty() && strides_.size() != spatial_dims) {
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
    }

    if (!pads_.empty() && pads_.size() != 2 * spatial_dims) {
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
    }

    if (!dilations_.empty() && dilations_.size() != spatial_dims) {
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
    }
}

bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();

    auto& dataflow = this->get_parent();
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
    int index = parent.index(block);
    auto& transition = parent.at(index).second;

    // Get primitive type
    auto primitive_type = this->primitive_type(dataflow);
    types::Scalar scalar_type(primitive_type);

    // Get input edges
    auto in_edges = dataflow.in_edges(*this);
    auto in_edges_it = in_edges.begin();

    data_flow::Memlet* x_edge = nullptr;
    data_flow::Memlet* w_edge = nullptr;
    data_flow::Memlet* b_edge = nullptr;

    while (in_edges_it != in_edges.end()) {
        auto& edge = *in_edges_it;
        auto dst_conn = edge.dst_conn();
        if (dst_conn == "X") {
            x_edge = &edge;
        } else if (dst_conn == "W") {
            w_edge = &edge;
        } else if (dst_conn == "B") {
            b_edge = &edge;
        } else {
            throw InvalidSDFGException("ConvNode has unexpected input: " + dst_conn);
        }
        ++in_edges_it;
    }

    if (!x_edge || !w_edge) {
        throw InvalidSDFGException("ConvNode requires X and W inputs");
    }

    auto& y_edge = *dataflow.out_edges(*this).begin();

    // Get access nodes
    auto* x_node = static_cast<data_flow::AccessNode*>(&x_edge->src());
    auto* w_node = static_cast<data_flow::AccessNode*>(&w_edge->src());
    data_flow::AccessNode* b_node = b_edge ? static_cast<data_flow::AccessNode*>(&b_edge->src()) : nullptr;
    auto* y_node = static_cast<data_flow::AccessNode*>(&y_edge.dst());

    // Validate nodes are standalone in the block
    if (!x_node || dataflow.in_degree(*x_node) != 0 || !w_node || dataflow.in_degree(*w_node) != 0 || !y_node ||
        dataflow.out_degree(*y_node) != 0) {
        return false;
    }

    if (b_node && dataflow.in_degree(*b_node) != 0) {
        return false;
    }

    // Check that all other nodes in the block are the expected ones
    for (auto* nd : dataflow.data_nodes()) {
        if (nd != x_node && nd != w_node && nd != y_node && (!b_node || nd != b_node)) {
            return false; // there are other nodes we cannot handle
        }
    }

    // Support n-dimensional convolutions
    size_t spatial_dims = kernel_shape_.size();

    if (spatial_dims == 0) {
        return false; // Need at least 1 spatial dimension
    }

    // Get strides (default to 1 if not provided)
    std::vector<symbolic::Expression> strides_vec;
    for (size_t i = 0; i < spatial_dims; ++i) {
        if (i < strides_.size()) {
            strides_vec.push_back(strides_[i]);
        } else {
            strides_vec.push_back(symbolic::one());
        }
    }

    // Get padding (default to 0 if not provided)
    // Pads format: [begin_0, begin_1, ..., begin_n, end_0, end_1, ..., end_n]
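    // (Illustration, not from the original: for a 2-D Conv this layout means
    //  pads = [pad_top, pad_left, pad_bottom, pad_right], matching the ONNX
    //  Conv attribute convention [x1_begin, x2_begin, ..., x1_end, x2_end].)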
    std::vector<symbolic::Expression> pads_begin_vec;
    std::vector<symbolic::Expression> pads_end_vec;
    for (size_t i = 0; i < spatial_dims; ++i) {
        if (i < pads_.size()) {
            pads_begin_vec.push_back(pads_[i]);
        } else {
            pads_begin_vec.push_back(symbolic::zero());
        }

        if (spatial_dims + i < pads_.size()) {
            pads_end_vec.push_back(pads_[spatial_dims + i]);
        } else {
            pads_end_vec.push_back(symbolic::zero());
        }
    }

    // Get dilations (default to 1 if not provided)
    std::vector<symbolic::Expression> dilations_vec;
    for (size_t i = 0; i < spatial_dims; ++i) {
        if (i < dilations_.size()) {
            dilations_vec.push_back(dilations_[i]);
        } else {
            dilations_vec.push_back(symbolic::one());
        }
    }

    // Get variable names
    auto& X_var = x_node->data();
    auto& W_var = w_node->data();
    auto& Y_var = y_node->data();

    // Use shape_ for dimensions if available
    // For a generic n-dimensional implementation:
    // Input X shape: [N, C_in, D0_in, D1_in, ..., Dn_in]
    symbolic::Expression N, C_in;
    std::vector<symbolic::Expression> input_spatial_dims;

    if (shape_.size() >= 2 + spatial_dims) {
        N = shape_[0];
        C_in = shape_[1];
        for (size_t i = 0; i < spatial_dims; ++i) {
            input_spatial_dims.push_back(shape_[2 + i]);
        }
    } else {
        N = symbolic::symbol(builder.find_new_name("N"));
        C_in = symbolic::symbol(builder.find_new_name("C_in"));
        for (size_t i = 0; i < spatial_dims; ++i) {
            input_spatial_dims.push_back(symbolic::symbol(builder.find_new_name("D" + std::to_string(i) + "_in")));
        }
    }

    // Output channel (C_out) is not in the input shape, treat it as a symbol
    auto C_out = symbolic::symbol(builder.find_new_name("C_out"));

    // Calculate output spatial dimensions
    std::vector<symbolic::Expression> output_spatial_dims;
    for (size_t i = 0; i < spatial_dims; ++i) {
        // D_out = floor((D_in + pads_begin + pads_end - dilation * (kernel - 1) - 1) / stride) + 1
        auto d_in = input_spatial_dims[i];
10✔
227
        auto pad = symbolic::add(pads_begin_vec[i], pads_end_vec[i]);
10✔
228
        auto dk = symbolic::mul(dilations_vec[i], symbolic::sub(kernel_shape_[i], symbolic::one()));
10✔
229
        auto num = symbolic::sub(symbolic::add(d_in, pad), symbolic::add(dk, symbolic::one()));
10✔
230
        auto d_out = symbolic::add(symbolic::div(num, strides_vec[i]), symbolic::one());
10✔
231
        output_spatial_dims.push_back(d_out);
10✔
232
    }
10✔
233

234
    // Create new sequence for expansion
235
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
5✔
236

237
    // Create nested map structure for convolution
238
    structured_control_flow::Sequence* current_scope = &new_sequence;
5✔
239
    std::vector<symbolic::Expression> output_indices;
5✔
240
    std::vector<symbolic::Expression> output_spatial_vars;
5✔
241

242
    // Map over batch dimension
243
    std::string n_str = builder.find_new_name("n");
5✔
244
    builder.add_container(n_str, types::Scalar(types::PrimitiveType::UInt64));
5✔
245
    auto n_var = symbolic::symbol(n_str);
5✔
246
    auto& map_n = builder.add_map(
5✔
247
        *current_scope,
5✔
248
        n_var,
5✔
249
        symbolic::Lt(n_var, N),
5✔
250
        symbolic::zero(),
5✔
251
        symbolic::add(n_var, symbolic::one()),
5✔
252
        structured_control_flow::ScheduleType_Sequential::create(),
5✔
253
        {},
5✔
254
        block.debug_info()
5✔
255
    );
5✔
256
    current_scope = &map_n.root();
5✔
257
    output_indices.push_back(n_var);
5✔
258

259
    // Map over output channel dimension
260
    std::string oc_str = builder.find_new_name("oc");
5✔
261
    builder.add_container(oc_str, types::Scalar(types::PrimitiveType::UInt64));
5✔
262
    auto oc_var = symbolic::symbol(oc_str);
5✔
263
    auto& map_oc = builder.add_map(
5✔
264
        *current_scope,
5✔
265
        oc_var,
5✔
266
        symbolic::Lt(oc_var, C_out),
5✔
267
        symbolic::zero(),
5✔
268
        symbolic::add(oc_var, symbolic::one()),
5✔
269
        structured_control_flow::ScheduleType_Sequential::create(),
5✔
270
        {},
5✔
271
        block.debug_info()
5✔
272
    );
5✔
273
    current_scope = &map_oc.root();
5✔
274
    output_indices.push_back(oc_var);
5✔
275

276
    // Map over each output spatial dimension dynamically
277
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
278
        std::string od_str = builder.find_new_name("od" + std::to_string(i));
10✔
279
        builder.add_container(od_str, types::Scalar(types::PrimitiveType::UInt64));
10✔
280
        auto od_var = symbolic::symbol(od_str);
10✔
281
        auto& map_od = builder.add_map(
10✔
282
            *current_scope,
10✔
283
            od_var,
10✔
284
            symbolic::Lt(od_var, output_spatial_dims[i]),
10✔
285
            symbolic::zero(),
10✔
286
            symbolic::add(od_var, symbolic::one()),
10✔
287
            structured_control_flow::ScheduleType_Sequential::create(),
10✔
288
            {},
10✔
289
            block.debug_info()
10✔
290
        );
10✔
291
        current_scope = &map_od.root();
10✔
292
        output_indices.push_back(od_var);
10✔
293
        output_spatial_vars.push_back(od_var);
10✔
294
    }
10✔
295

296
    // Create accumulator variable for the sum
297
    std::string accum_var = builder.find_new_name("_conv_accum");
5✔
298
    builder.add_container(accum_var, scalar_type);
5✔
299

300
    // Initialize accumulator to 0
301
    auto& init_block = builder.add_block(*current_scope, {}, block.debug_info());
5✔
302
    auto& accum_init = builder.add_access(init_block, accum_var, block.debug_info());
5✔
303
    auto& zero_const = builder.add_constant(init_block, "0.0", scalar_type, block.debug_info());
5✔
304
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
5✔
305
    builder.add_computational_memlet(init_block, zero_const, init_tasklet, "_in", {}, scalar_type, block.debug_info());
5✔
306
    builder.add_computational_memlet(init_block, init_tasklet, "_out", accum_init, {}, scalar_type, block.debug_info());
5✔
307

308
    // Create nested for loops for input channels and kernel dimensions
309
    // For loop over input channels
310
    std::string ic_str = builder.find_new_name("ic");
5✔
311
    builder.add_container(ic_str, types::Scalar(types::PrimitiveType::UInt64));
5✔
312
    auto ic_var = symbolic::symbol(ic_str);
5✔
313
    auto& for_ic = builder.add_for(
5✔
314
        *current_scope,
5✔
315
        ic_var,
5✔
316
        symbolic::Lt(ic_var, C_in),
5✔
317
        symbolic::zero(),
5✔
318
        symbolic::add(ic_var, symbolic::one()),
5✔
319
        {},
5✔
320
        block.debug_info()
5✔
321
    );
5✔
322
    auto* loop_scope = &for_ic.root();
5✔
323

324
    // For loops over each kernel spatial dimension
325
    std::vector<symbolic::Expression> kernel_vars;
5✔
326
    for (size_t i = 0; i < spatial_dims; ++i) {
15✔
327
        std::string k_str = builder.find_new_name("k" + std::to_string(i));
10✔
328
        builder.add_container(k_str, types::Scalar(types::PrimitiveType::UInt64));
10✔
329
        auto k_var = symbolic::symbol(k_str);
10✔
330
        auto& for_k = builder.add_for(
10✔
331
            *loop_scope,
10✔
332
            k_var,
10✔
333
            symbolic::Lt(k_var, kernel_shape_[i]),
10✔
334
            symbolic::zero(),
10✔
335
            symbolic::add(k_var, symbolic::one()),
10✔
336
            {},
10✔
337
            block.debug_info()
10✔
338
        );
10✔
339
        loop_scope = &for_k.root();
10✔
340
        kernel_vars.push_back(k_var);
10✔
341
    }
10✔
342

343
    // Compute indices for input and weight access
344
    // Input index: [n, ic, od0 * stride0 - pad0 + k0 * dilation0, ...]
345
    // Note: taking dilation into account for input index calculation
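    // (Worked example, for illustration: with stride = 1, pad_begin = 1, and
    //  dilation = 2, the read position on spatial axis i is od_i * 1 - 1 + k_i * 2.)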
    std::vector<symbolic::Expression> input_spatial_indices;
    for (size_t i = 0; i < spatial_dims; ++i) {
        auto k_dilated = symbolic::mul(kernel_vars[i], dilations_vec[i]);
        auto input_idx =
            symbolic::add(symbolic::sub(symbolic::mul(output_spatial_vars[i], strides_vec[i]), pads_begin_vec[i]), k_dilated);
        input_spatial_indices.push_back(input_idx);
    }

    // Create computation block
    auto& comp_block = builder.add_block(*loop_scope, {}, block.debug_info());

    // Access input X[n, ic, input_spatial_indices...]
    auto& x_access = builder.add_access(comp_block, X_var, x_node->debug_info());
    // Access weight W[oc, ic, k0, k1, ...]
    auto& w_access = builder.add_access(comp_block, W_var, w_node->debug_info());
    // Access accumulator
    auto& accum_read = builder.add_access(comp_block, accum_var, block.debug_info());
    auto& accum_write = builder.add_access(comp_block, accum_var, block.debug_info());

    // Create FMA tasklet: accum = accum + x * w
    auto& fma_tasklet =
        builder.add_tasklet(comp_block, data_flow::fp_fma, "_out", {"_in1", "_in2", "_in3"}, block.debug_info());

    // Linearization helper
    auto linearize = [&](const std::vector<symbolic::Expression>& indices,
                         const std::vector<symbolic::Expression>& shape) -> symbolic::Expression {
        symbolic::Expression idx = symbolic::zero();
        symbolic::Expression stride = symbolic::one();
        for (int i = shape.size() - 1; i >= 0; --i) {
            idx = symbolic::add(idx, symbolic::mul(indices[i], stride));
            stride = symbolic::mul(stride, shape[i]);
        }
        return idx;
    };
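    // (For illustration: this computes the standard row-major flattening. For
    //  indices [n, c, h, w] against shape [N, C, H, W] it yields
    //  w + W * (h + H * (c + C * n)), built up from the innermost axis outwards.)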

    // Calculate shapes for linearization
    // X shape: [N, C_in, D0, D1...]
    std::vector<symbolic::Expression> x_shape_vec = {N, C_in};
    x_shape_vec.insert(x_shape_vec.end(), input_spatial_dims.begin(), input_spatial_dims.end());

    // W shape: [C_out, C_in/group, k0, k1...]
    std::vector<symbolic::Expression> w_shape_vec = {C_out, symbolic::div(C_in, group_)};
    w_shape_vec.insert(w_shape_vec.end(), kernel_shape_.begin(), kernel_shape_.end());

    // Connect edges with linearized subsets
    std::vector<symbolic::Expression> x_indices_vec = {n_var, ic_var};
    x_indices_vec.insert(x_indices_vec.end(), input_spatial_indices.begin(), input_spatial_indices.end());

    std::vector<symbolic::Expression> w_indices_vec = {oc_var, ic_var}; // Assuming group=1 for now for simplicity of indices
    // TODO: Handle groups properly in indices if needed, but for standard conv:
    w_indices_vec.insert(w_indices_vec.end(), kernel_vars.begin(), kernel_vars.end());

    data_flow::Subset x_subset({linearize(x_indices_vec, x_shape_vec)});
    data_flow::Subset w_subset({linearize(w_indices_vec, w_shape_vec)});

    builder.add_computational_memlet(
        comp_block, x_access, fma_tasklet, "_in1", x_subset, x_edge->base_type(), x_edge->debug_info()
    );
    builder.add_computational_memlet(
        comp_block, w_access, fma_tasklet, "_in2", w_subset, w_edge->base_type(), w_edge->debug_info()
    );
    builder.add_computational_memlet(comp_block, accum_read, fma_tasklet, "_in3", {}, scalar_type, block.debug_info());
    builder.add_computational_memlet(comp_block, fma_tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());

    // After all loops, write accumulated result to output (with optional bias)
    auto& output_block = builder.add_block(*current_scope, {}, block.debug_info());
    auto& accum_final = builder.add_access(output_block, accum_var, block.debug_info());
    auto& y_access = builder.add_access(output_block, Y_var, y_node->debug_info());

    // Y shape: [N, C_out, D0_out, ...]
    std::vector<symbolic::Expression> y_shape_vec = {N, C_out};
    y_shape_vec.insert(y_shape_vec.end(), output_spatial_dims.begin(), output_spatial_dims.end());

    data_flow::Subset y_subset({linearize(output_indices, y_shape_vec)});

    if (b_node) {
        // Add bias: output = accum + bias[oc]
        auto& b_access = builder.add_access(output_block, b_node->data(), b_node->debug_info());
        auto& add_tasklet =
            builder.add_tasklet(output_block, data_flow::fp_add, "_out", {"_in1", "_in2"}, block.debug_info());

        builder.add_computational_memlet(
            output_block, accum_final, add_tasklet, "_in1", {}, scalar_type, block.debug_info()
        );
        builder.add_computational_memlet(
            output_block, b_access, add_tasklet, "_in2", {oc_var}, b_edge->base_type(), b_edge->debug_info()
        );
        builder.add_computational_memlet(
            output_block, add_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
        );
    } else {
        // No bias: output = accum
        auto& assign_tasklet =
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());

        builder.add_computational_memlet(
            output_block, accum_final, assign_tasklet, "_in", {}, scalar_type, block.debug_info()
        );
        builder.add_computational_memlet(
            output_block, assign_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
        );
    }

    // Clean up the original block
    builder.remove_memlet(block, *x_edge);
    builder.remove_memlet(block, *w_edge);
    if (b_edge) {
        builder.remove_memlet(block, *b_edge);
        builder.remove_node(block, *b_node);
    }
    builder.remove_memlet(block, y_edge);
    builder.remove_node(block, *x_node);
    builder.remove_node(block, *w_node);
    builder.remove_node(block, *y_node);
    builder.remove_node(block, *this);
    builder.remove_child(parent, index + 1);

    return true;
}

symbolic::SymbolSet ConvNode::symbols() const {
    symbolic::SymbolSet syms;

    for (auto& expr : shape_) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    }
    for (auto& expr : kernel_shape_) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    }
    for (auto& expr : strides_) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    }
    for (auto& expr : pads_) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    }
    for (auto& expr : dilations_) {
        for (auto& atom : symbolic::atoms(expr)) {
            syms.insert(atom);
        }
    }
    for (auto& atom : symbolic::atoms(group_)) {
        syms.insert(atom);
    }

    return syms;
}

void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    for (auto& expr : shape_) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    }
    for (auto& expr : kernel_shape_) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    }
    for (auto& expr : strides_) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    }
    for (auto& expr : pads_) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    }
    for (auto& expr : dilations_) {
        expr = symbolic::subs(expr, old_expression, new_expression);
    }
    group_ = symbolic::subs(group_, old_expression, new_expression);
}

std::unique_ptr<data_flow::DataFlowNode> ConvNode::
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
        element_id, this->debug_info(), vertex, parent, shape_, kernel_shape_, strides_, pads_, dilations_, group_
    ));
}

std::string ConvNode::toStr() const {
    std::string result = "Conv(shape=[";
    for (size_t i = 0; i < shape_.size(); ++i) {
        if (i > 0) result += ", ";
        result += shape_[i]->__str__();
    }
    result += "], kernel_shape=[";
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
        if (i > 0) result += ", ";
        result += kernel_shape_[i]->__str__();
    }
    result += "], strides=[";
    for (size_t i = 0; i < strides_.size(); ++i) {
        if (i > 0) result += ", ";
        result += strides_[i]->__str__();
    }
    result += "], group=" + group_->__str__() + ")";
    return result;
}

nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const ConvNode& conv_node = static_cast<const ConvNode&>(library_node);
    nlohmann::json j;

    j["code"] = conv_node.code().value();

    serializer::JSONSerializer serializer;

    j["shape"] = nlohmann::json::array();
    for (auto& dim : conv_node.shape()) {
        j["shape"].push_back(serializer.expression(dim));
    }

    j["kernel_shape"] = nlohmann::json::array();
    for (auto& dim : conv_node.kernel_shape()) {
        j["kernel_shape"].push_back(serializer.expression(dim));
    }

    j["strides"] = nlohmann::json::array();
    for (auto& stride : conv_node.strides()) {
        j["strides"].push_back(serializer.expression(stride));
    }

    j["pads"] = nlohmann::json::array();
    for (auto& pad : conv_node.pads()) {
        j["pads"].push_back(serializer.expression(pad));
    }

    j["dilations"] = nlohmann::json::array();
    for (auto& dilation : conv_node.dilations()) {
        j["dilations"].push_back(serializer.expression(dilation));
    }

    j["group"] = serializer.expression(conv_node.group());

    return j;
}
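// (Illustrative sketch, not produced by the original: for a hypothetical 2-D Conv
//  with static sizes, serialize() would emit JSON of roughly this form:
//    {"code": ..., "shape": ["1", "3", "224", "224"], "kernel_shape": ["3", "3"],
//     "strides": ["2", "2"], "pads": ["1", "1", "1", "1"], "dilations": ["1", "1"],
//     "group": "1"}
//  where the "code" value comes from conv_node.code().value() and each entry is
//  formatted by serializer.expression().)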

data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    std::vector<symbolic::Expression> shape;
    if (j.contains("shape")) {
        for (const auto& dim : j["shape"]) {
            shape.push_back(symbolic::parse(dim.get<std::string>()));
        }
    }

    std::vector<symbolic::Expression> kernel_shape;
    for (const auto& dim : j["kernel_shape"]) {
        kernel_shape.push_back(symbolic::parse(dim.get<std::string>()));
    }

    std::vector<symbolic::Expression> strides;
    if (j.contains("strides")) {
        for (const auto& stride : j["strides"]) {
            strides.push_back(symbolic::parse(stride.get<std::string>()));
        }
    }

    std::vector<symbolic::Expression> pads;
    if (j.contains("pads")) {
        for (const auto& pad : j["pads"]) {
            pads.push_back(symbolic::parse(pad.get<std::string>()));
        }
    }

    std::vector<symbolic::Expression> dilations;
    if (j.contains("dilations")) {
        for (const auto& dilation : j["dilations"]) {
            dilations.push_back(symbolic::parse(dilation.get<std::string>()));
        }
    }

    symbolic::Expression group = symbolic::one();
    if (j.contains("group")) {
        group = symbolic::parse(j["group"].get<std::string>());
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, group);
}

} // namespace tensor
} // namespace math
} // namespace sdfg