• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22992592591

12 Mar 2026 08:14AM UTC coverage: 63.488% (-0.007%) from 63.495%
22992592591

push

github

web-flow
Merge pull request #576 from daisytuner/depthwise2DConv

[MLIR] Depthwise2dconv

0 of 7 new or added lines in 1 file covered. (0.0%)

2 existing lines in 1 file now uncovered.

24709 of 38919 relevant lines covered (63.49%)

369.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/data_flow/library_nodes/math/tensor/conv_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/conv_node.h"
2

3
#include <map>
4

5
#include "sdfg/analysis/analysis.h"
6
#include "sdfg/builder/structured_sdfg_builder.h"
7
#include "sdfg/types/type.h"
8

9
#include "sdfg/analysis/scope_analysis.h"
10
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
11

12
namespace sdfg {
13
namespace math {
14
namespace tensor {
15

16
ConvNode::ConvNode(
17
    size_t element_id,
18
    const DebugInfo& debug_info,
19
    const graph::Vertex vertex,
20
    data_flow::DataFlowGraph& parent,
21
    const std::vector<symbolic::Expression>& shape,
22
    const std::vector<symbolic::Expression>& kernel_shape,
23
    const std::vector<symbolic::Expression>& strides,
24
    const std::vector<symbolic::Expression>& pads,
25
    const std::vector<symbolic::Expression>& dilations,
26
    symbolic::Expression output_channels,
27
    symbolic::Expression group
28
)
29
    : TensorNode(
×
30
          element_id,
×
31
          debug_info,
×
32
          vertex,
×
33
          parent,
×
34
          LibraryNodeType_Conv,
×
35
          {"Y"},
×
36
          {"X", "W", "B"}, // X and W are required, B (bias) is optional
×
37
          data_flow::ImplementationType_NONE
×
38
      ),
×
39
      shape_(shape), kernel_shape_(kernel_shape), strides_(strides), pads_(pads), dilations_(dilations),
×
40
      output_channels_(output_channels), group_(group) {}
×
41

42
void ConvNode::validate(const Function& function) const {
×
43
    TensorNode::validate(function);
×
44

45
    auto& graph = this->get_parent();
×
46

47
    // Custom validation for ConvNode that handles optional bias input
48
    // We expect X, W as required inputs and optionally B (bias)
49

50
    // Collect all input edges by connector name
51
    std::map<std::string, const data_flow::Memlet*> input_edges;
×
52
    for (auto& iedge : graph.in_edges(*this)) {
×
53
        input_edges[iedge.dst_conn()] = &iedge;
×
54
    }
×
55

56
    // Check that required inputs X and W are present
57
    if (input_edges.find("X") == input_edges.end()) {
×
58
        throw InvalidSDFGException("ConvNode: Required input 'X' is not connected");
×
59
    }
×
60
    if (input_edges.find("W") == input_edges.end()) {
×
61
        throw InvalidSDFGException("ConvNode: Required input 'W' is not connected");
×
62
    }
×
63

64
    // Validate kernel shape is not empty
65
    if (kernel_shape_.empty()) {
×
66
        throw InvalidSDFGException("ConvNode kernel_shape cannot be empty");
×
67
    }
×
68

69
    // Validate strides, pads, dilations have consistent dimensions
70
    size_t spatial_dims = kernel_shape_.size();
×
71

72
    if (!strides_.empty() && strides_.size() != spatial_dims) {
×
73
        throw InvalidSDFGException("ConvNode strides must match kernel spatial dimensions");
×
74
    }
×
75

76
    if (!pads_.empty() && pads_.size() != 2 * spatial_dims) {
×
77
        throw InvalidSDFGException("ConvNode pads must have 2 * spatial dimensions (start and end for each axis)");
×
78
    }
×
79

80
    if (!dilations_.empty() && dilations_.size() != spatial_dims) {
×
81
        throw InvalidSDFGException("ConvNode dilations must match kernel spatial dimensions");
×
82
    }
×
83
}
×
84

85
bool ConvNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
86
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
87

88
    auto& dataflow = this->get_parent();
×
89
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
90
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
91
    int index = parent.index(block);
×
92
    auto& transition = parent.at(index).second;
×
93

94
    // Get primitive type
95
    auto primitive_type = this->primitive_type(dataflow);
×
96
    types::Scalar scalar_type(primitive_type);
×
97

98
    // Get input edges
99
    auto in_edges = dataflow.in_edges(*this);
×
100
    auto in_edges_it = in_edges.begin();
×
101

102
    data_flow::Memlet* x_edge = nullptr;
×
103
    data_flow::Memlet* w_edge = nullptr;
×
104
    data_flow::Memlet* b_edge = nullptr;
×
105

106
    while (in_edges_it != in_edges.end()) {
×
107
        auto& edge = *in_edges_it;
×
108
        auto dst_conn = edge.dst_conn();
×
109
        if (dst_conn == "X") {
×
110
            x_edge = &edge;
×
111
        } else if (dst_conn == "W") {
×
112
            w_edge = &edge;
×
113
        } else if (dst_conn == "B") {
×
114
            b_edge = &edge;
×
115
        } else {
×
116
            throw InvalidSDFGException("ConvNode has unexpected input: " + dst_conn);
×
117
        }
×
118
        ++in_edges_it;
×
119
    }
×
120

121
    if (!x_edge || !w_edge) {
×
122
        throw InvalidSDFGException("ConvNode requires X and W inputs");
×
123
    }
×
124

125
    auto& y_edge = *dataflow.out_edges(*this).begin();
×
126

127
    // Get access nodes
128
    auto* x_node = static_cast<data_flow::AccessNode*>(&x_edge->src());
×
129
    auto* w_node = static_cast<data_flow::AccessNode*>(&w_edge->src());
×
130
    data_flow::AccessNode* b_node = b_edge ? static_cast<data_flow::AccessNode*>(&b_edge->src()) : nullptr;
×
131
    auto* y_node = static_cast<data_flow::AccessNode*>(&y_edge.dst());
×
132

133
    // Validate nodes are standalone in the block
134
    if (!x_node || dataflow.in_degree(*x_node) != 0 || !w_node || dataflow.in_degree(*w_node) != 0 || !y_node ||
×
135
        dataflow.out_degree(*y_node) != 0) {
×
136
        return false;
×
137
    }
×
138

139
    if (b_node && dataflow.in_degree(*b_node) != 0) {
×
140
        return false;
×
141
    }
×
142

143
    // Check that all other nodes in the block are the expected ones
144
    for (auto* nd : dataflow.data_nodes()) {
×
145
        if (nd != x_node && nd != w_node && nd != y_node && (!b_node || nd != b_node)) {
×
146
            return false; // there are other nodes we cannot handle
×
147
        }
×
148
    }
×
149

150
    // Support n-dimensional convolutions
151
    size_t spatial_dims = kernel_shape_.size();
×
152

153
    if (spatial_dims == 0) {
×
154
        return false; // Need at least 1 spatial dimension
×
155
    }
×
156

157
    // Get strides (default to 1 if not provided)
158
    std::vector<symbolic::Expression> strides_vec;
×
159
    for (size_t i = 0; i < spatial_dims; ++i) {
×
160
        if (i < strides_.size()) {
×
161
            strides_vec.push_back(strides_[i]);
×
162
        } else {
×
163
            strides_vec.push_back(symbolic::one());
×
164
        }
×
165
    }
×
166

167
    // Get padding (default to 0 if not provided)
168
    // Pads format: [begin_0, begin_1, ..., begin_n, end_0, end_1, ..., end_n]
169
    std::vector<symbolic::Expression> pads_begin_vec;
×
170
    std::vector<symbolic::Expression> pads_end_vec;
×
171
    for (size_t i = 0; i < spatial_dims; ++i) {
×
172
        if (i < pads_.size()) {
×
173
            pads_begin_vec.push_back(pads_[i]);
×
174
        } else {
×
175
            pads_begin_vec.push_back(symbolic::zero());
×
176
        }
×
177

178
        if (spatial_dims + i < pads_.size()) {
×
179
            pads_end_vec.push_back(pads_[spatial_dims + i]);
×
180
        } else {
×
181
            pads_end_vec.push_back(symbolic::zero());
×
182
        }
×
183
    }
×
184

185
    // Get dilations (default to 1 if not provided)
186
    std::vector<symbolic::Expression> dilations_vec;
×
187
    for (size_t i = 0; i < spatial_dims; ++i) {
×
188
        if (i < dilations_.size()) {
×
189
            dilations_vec.push_back(dilations_[i]);
×
190
        } else {
×
191
            dilations_vec.push_back(symbolic::one());
×
192
        }
×
193
    }
×
194

195
    // Get variable names
196
    auto& X_var = x_node->data();
×
197
    auto& W_var = w_node->data();
×
198
    auto& Y_var = y_node->data();
×
199

200
    // Use shape_ for dimensions if available
201
    // For a generic n-dimensional implementation:
202
    // Input X shape: [N, C_in, D0_in, D1_in, ..., Dn_in]
203
    symbolic::Expression N, C_in;
×
204
    std::vector<symbolic::Expression> input_spatial_dims;
×
205

206
    if (shape_.size() >= 2 + spatial_dims) {
×
207
        N = shape_[0];
×
208
        C_in = shape_[1];
×
209
        for (size_t i = 0; i < spatial_dims; ++i) {
×
210
            input_spatial_dims.push_back(shape_[2 + i]);
×
211
        }
×
212
    } else {
×
213
        N = symbolic::symbol(builder.find_new_name("N"));
×
214
        C_in = symbolic::symbol(builder.find_new_name("C_in"));
×
215
        for (size_t i = 0; i < spatial_dims; ++i) {
×
216
            input_spatial_dims.push_back(symbolic::symbol(builder.find_new_name("D" + std::to_string(i) + "_in")));
×
217
        }
×
218
    }
×
219

220
    // Output Channel (C_out) is passed via constructor
221
    auto C_out = output_channels_;
×
222

223
    // Calculate output spatial dimensions
224
    std::vector<symbolic::Expression> output_spatial_dims;
×
225
    for (size_t i = 0; i < spatial_dims; ++i) {
×
226
        // D_out = floor((D_in + pads_begin + pads_end - dilation * (kernel - 1) - 1) / stride) + 1
227
        auto d_in = input_spatial_dims[i];
×
228
        auto pad = symbolic::add(pads_begin_vec[i], pads_end_vec[i]);
×
229
        auto dk = symbolic::mul(dilations_vec[i], symbolic::sub(kernel_shape_[i], symbolic::one()));
×
230
        auto num = symbolic::sub(symbolic::add(d_in, pad), symbolic::add(dk, symbolic::one()));
×
231
        auto d_out = symbolic::add(symbolic::div(num, strides_vec[i]), symbolic::one());
×
232
        output_spatial_dims.push_back(d_out);
×
233
    }
×
234

235
    // Create new sequence for expansion
236
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());
×
237

238
    // Create nested map structure for convolution
239
    structured_control_flow::Sequence* current_scope = &new_sequence;
×
240
    std::vector<symbolic::Expression> output_indices;
×
241
    std::vector<symbolic::Expression> output_spatial_vars;
×
242

243
    // Map over batch dimension
244
    std::string n_str = builder.find_new_name("n");
×
245
    builder.add_container(n_str, types::Scalar(types::PrimitiveType::UInt64));
×
246
    auto n_var = symbolic::symbol(n_str);
×
247
    auto& map_n = builder.add_map(
×
248
        *current_scope,
×
249
        n_var,
×
250
        symbolic::Lt(n_var, N),
×
251
        symbolic::zero(),
×
252
        symbolic::add(n_var, symbolic::one()),
×
253
        structured_control_flow::ScheduleType_Sequential::create(),
×
254
        {},
×
255
        block.debug_info()
×
256
    );
×
257
    current_scope = &map_n.root();
×
258
    output_indices.push_back(n_var);
×
259

260
    // Map over output channel dimension
261
    std::string oc_str = builder.find_new_name("oc");
×
262
    builder.add_container(oc_str, types::Scalar(types::PrimitiveType::UInt64));
×
263
    auto oc_var = symbolic::symbol(oc_str);
×
264
    auto& map_oc = builder.add_map(
×
265
        *current_scope,
×
266
        oc_var,
×
267
        symbolic::Lt(oc_var, C_out),
×
268
        symbolic::zero(),
×
269
        symbolic::add(oc_var, symbolic::one()),
×
270
        structured_control_flow::ScheduleType_Sequential::create(),
×
271
        {},
×
272
        block.debug_info()
×
273
    );
×
274
    current_scope = &map_oc.root();
×
275
    output_indices.push_back(oc_var);
×
276

277
    // Map over each output spatial dimension dynamically
278
    for (size_t i = 0; i < spatial_dims; ++i) {
×
279
        std::string od_str = builder.find_new_name("od" + std::to_string(i));
×
280
        builder.add_container(od_str, types::Scalar(types::PrimitiveType::UInt64));
×
281
        auto od_var = symbolic::symbol(od_str);
×
282
        auto& map_od = builder.add_map(
×
283
            *current_scope,
×
284
            od_var,
×
285
            symbolic::Lt(od_var, output_spatial_dims[i]),
×
286
            symbolic::zero(),
×
287
            symbolic::add(od_var, symbolic::one()),
×
288
            structured_control_flow::ScheduleType_Sequential::create(),
×
289
            {},
×
290
            block.debug_info()
×
291
        );
×
292
        current_scope = &map_od.root();
×
293
        output_indices.push_back(od_var);
×
294
        output_spatial_vars.push_back(od_var);
×
295
    }
×
296

297
    // Create accumulator variable for the sum
298
    std::string accum_var = builder.find_new_name("_conv_accum");
×
299
    builder.add_container(accum_var, scalar_type);
×
300

301
    // Initialize accumulator to 0
302
    auto& init_block = builder.add_block(*current_scope, {}, block.debug_info());
×
303
    auto& accum_init = builder.add_access(init_block, accum_var, block.debug_info());
×
304
    auto& zero_const = builder.add_constant(init_block, "0.0", scalar_type, block.debug_info());
×
305
    auto& init_tasklet = builder.add_tasklet(init_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
×
306
    builder.add_computational_memlet(init_block, zero_const, init_tasklet, "_in", {}, scalar_type, block.debug_info());
×
307
    builder.add_computational_memlet(init_block, init_tasklet, "_out", accum_init, {}, scalar_type, block.debug_info());
×
308

309
    // Create nested for loops for input channels and kernel dimensions
310
    // For grouped convolution, each output channel group only reads C_in/group input channels
NEW
311
    auto C_in_per_group = symbolic::div(C_in, group_);
×
312

313
    // For loop over input channels (per group)
314
    std::string ic_str = builder.find_new_name("ic");
×
315
    builder.add_container(ic_str, types::Scalar(types::PrimitiveType::UInt64));
×
316
    auto ic_var = symbolic::symbol(ic_str);
×
317
    auto& for_ic = builder.add_for(
×
318
        *current_scope,
×
319
        ic_var,
×
NEW
320
        symbolic::Lt(ic_var, C_in_per_group),
×
321
        symbolic::zero(),
×
322
        symbolic::add(ic_var, symbolic::one()),
×
323
        {},
×
324
        block.debug_info()
×
325
    );
×
326
    auto* loop_scope = &for_ic.root();
×
327

328
    // For loops over each kernel spatial dimension
329
    std::vector<symbolic::Expression> kernel_vars;
×
330
    for (size_t i = 0; i < spatial_dims; ++i) {
×
331
        std::string k_str = builder.find_new_name("k" + std::to_string(i));
×
332
        builder.add_container(k_str, types::Scalar(types::PrimitiveType::UInt64));
×
333
        auto k_var = symbolic::symbol(k_str);
×
334
        auto& for_k = builder.add_for(
×
335
            *loop_scope,
×
336
            k_var,
×
337
            symbolic::Lt(k_var, kernel_shape_[i]),
×
338
            symbolic::zero(),
×
339
            symbolic::add(k_var, symbolic::one()),
×
340
            {},
×
341
            block.debug_info()
×
342
        );
×
343
        loop_scope = &for_k.root();
×
344
        kernel_vars.push_back(k_var);
×
345
    }
×
346

347
    // Compute indices for input and weight access
348
    // Input index: [n, ic, od0 * stride0 - pad0 + k0 * dilation0, ...]
349
    // Note: taking dilation into account for input index calculation
350
    std::vector<symbolic::Expression> input_spatial_indices;
×
351
    for (size_t i = 0; i < spatial_dims; ++i) {
×
352
        auto k_dilated = symbolic::mul(kernel_vars[i], dilations_vec[i]);
×
353
        auto input_idx = symbolic::
×
354
            add(symbolic::sub(symbolic::mul(output_spatial_vars[i], strides_vec[i]), pads_begin_vec[i]), k_dilated);
×
355
        input_spatial_indices.push_back(input_idx);
×
356
    }
×
357

358
    // Create computation block
359
    auto& comp_block = builder.add_block(*loop_scope, {}, block.debug_info());
×
360

361
    // Access input X[n, ic, input_spatial_indices...]
362
    auto& x_access = builder.add_access(comp_block, X_var, x_node->debug_info());
×
363
    // Access weight W[oc, ic, k0, k1, ...]
364
    auto& w_access = builder.add_access(comp_block, W_var, w_node->debug_info());
×
365
    // Access accumulator
366
    auto& accum_read = builder.add_access(comp_block, accum_var, block.debug_info());
×
367
    auto& accum_write = builder.add_access(comp_block, accum_var, block.debug_info());
×
368

369
    // Create FMA tasklet: accum = accum + x * w
370
    auto& fma_tasklet =
×
371
        builder.add_tasklet(comp_block, data_flow::fp_fma, "_out", {"_in1", "_in2", "_in3"}, block.debug_info());
×
372

373

374
    // Calculate shapes for
375
    // X shape: [N, C_in, D0, D1...]
376
    std::vector<symbolic::Expression> x_shape_vec = {N, C_in};
×
377
    x_shape_vec.insert(x_shape_vec.end(), input_spatial_dims.begin(), input_spatial_dims.end());
×
378

379
    // W shape: [C_out, C_in/group, k0, k1...]
380
    std::vector<symbolic::Expression> w_shape_vec = {C_out, symbolic::div(C_in, group_)};
×
381
    w_shape_vec.insert(w_shape_vec.end(), kernel_shape_.begin(), kernel_shape_.end());
×
382

383
    // Connect edges with subsets
384
    // For grouped conv, compute group index g = oc / (C_out/group), then
385
    // input channel = g * (C_in/group) + ic
386
    // For group=1: g=0, input_channel=ic. For depthwise (group=C): g=oc, input_channel=oc+ic.
NEW
387
    auto C_out_per_group = symbolic::div(C_out, group_);
×
NEW
388
    auto group_idx = symbolic::div(oc_var, C_out_per_group);
×
NEW
389
    auto input_channel_idx = symbolic::add(symbolic::mul(group_idx, C_in_per_group), ic_var);
×
NEW
390
    std::vector<symbolic::Expression> x_indices_vec = {n_var, input_channel_idx};
×
UNCOV
391
    x_indices_vec.insert(x_indices_vec.end(), input_spatial_indices.begin(), input_spatial_indices.end());
×
392

NEW
393
    std::vector<symbolic::Expression> w_indices_vec = {oc_var, ic_var};
×
UNCOV
394
    w_indices_vec.insert(w_indices_vec.end(), kernel_vars.begin(), kernel_vars.end());
×
395

396
    data_flow::Subset x_subset(x_indices_vec);
×
397
    data_flow::Subset w_subset(w_indices_vec);
×
398

399
    builder.add_computational_memlet(
×
400
        comp_block, x_access, fma_tasklet, "_in1", x_subset, x_edge->base_type(), x_edge->debug_info()
×
401
    );
×
402
    builder.add_computational_memlet(
×
403
        comp_block, w_access, fma_tasklet, "_in2", w_subset, w_edge->base_type(), w_edge->debug_info()
×
404
    );
×
405
    builder.add_computational_memlet(comp_block, accum_read, fma_tasklet, "_in3", {}, scalar_type, block.debug_info());
×
406
    builder.add_computational_memlet(comp_block, fma_tasklet, "_out", accum_write, {}, scalar_type, block.debug_info());
×
407

408
    // After all loops, write accumulated result to output (with optional bias)
409
    auto& output_block = builder.add_block(*current_scope, {}, block.debug_info());
×
410
    auto& accum_final = builder.add_access(output_block, accum_var, block.debug_info());
×
411
    auto& y_access = builder.add_access(output_block, Y_var, y_node->debug_info());
×
412

413
    // Y shape: [N, C_out, D0_out, ...]
414
    std::vector<symbolic::Expression> y_shape_vec = {N, C_out};
×
415
    y_shape_vec.insert(y_shape_vec.end(), output_spatial_dims.begin(), output_spatial_dims.end());
×
416

417
    data_flow::Subset y_subset(output_indices);
×
418

419
    if (b_node) {
×
420
        // Add bias: output = accum + bias[oc]
421
        auto& b_access = builder.add_access(output_block, b_node->data(), b_node->debug_info());
×
422
        auto& add_tasklet =
×
423
            builder.add_tasklet(output_block, data_flow::fp_add, "_out", {"_in1", "_in2"}, block.debug_info());
×
424

425
        builder
×
426
            .add_computational_memlet(output_block, accum_final, add_tasklet, "_in1", {}, scalar_type, block.debug_info());
×
427
        builder.add_computational_memlet(
×
428
            output_block, b_access, add_tasklet, "_in2", {oc_var}, b_edge->base_type(), b_edge->debug_info()
×
429
        );
×
430
        builder.add_computational_memlet(
×
431
            output_block, add_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
×
432
        );
×
433
    } else {
×
434
        // No bias: output = accum
435
        auto& assign_tasklet =
×
436
            builder.add_tasklet(output_block, data_flow::assign, "_out", {"_in"}, block.debug_info());
×
437

438
        builder.add_computational_memlet(
×
439
            output_block, accum_final, assign_tasklet, "_in", {}, scalar_type, block.debug_info()
×
440
        );
×
441
        builder.add_computational_memlet(
×
442
            output_block, assign_tasklet, "_out", y_access, y_subset, y_edge.base_type(), y_edge.debug_info()
×
443
        );
×
444
    }
×
445

446
    // Clean up the original block
447
    builder.remove_memlet(block, *x_edge);
×
448
    builder.remove_memlet(block, *w_edge);
×
449
    if (b_edge) {
×
450
        builder.remove_memlet(block, *b_edge);
×
451
        builder.remove_node(block, *b_node);
×
452
    }
×
453
    builder.remove_memlet(block, y_edge);
×
454
    builder.remove_node(block, *x_node);
×
455
    builder.remove_node(block, *w_node);
×
456
    builder.remove_node(block, *y_node);
×
457
    builder.remove_node(block, *this);
×
458
    builder.remove_child(parent, index + 1);
×
459

460
    return true;
×
461
}
×
462

463
/**
 * Collects every symbol that occurs in the node's attributes: input shape,
 * kernel shape, strides, pads, dilations, output channel count and group.
 */
symbolic::SymbolSet ConvNode::symbols() const {
    symbolic::SymbolSet collected;

    // Adds all atoms of a single expression to the result set.
    const auto gather = [&collected](const symbolic::Expression& expr) {
        for (auto& atom : symbolic::atoms(expr)) {
            collected.insert(atom);
        }
    };

    for (const auto* attribute : {&shape_, &kernel_shape_, &strides_, &pads_, &dilations_}) {
        for (const auto& expr : *attribute) {
            gather(expr);
        }
    }
    gather(output_channels_);
    gather(group_);

    return collected;
}
×
500

501
/**
 * Substitutes old_expression by new_expression in every attribute of the
 * node (shape, kernel shape, strides, pads, dilations, output channels and
 * group), updating the stored expressions in place.
 */
void ConvNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Single substitution step shared by all attributes.
    const auto substituted = [&old_expression, &new_expression](const symbolic::Expression& expr) {
        return symbolic::subs(expr, old_expression, new_expression);
    };

    for (auto* attribute : {&shape_, &kernel_shape_, &strides_, &pads_, &dilations_}) {
        for (auto& expr : *attribute) {
            expr = substituted(expr);
        }
    }

    output_channels_ = substituted(output_channels_);
    group_ = substituted(group_);
}
×
520

521
std::unique_ptr<data_flow::DataFlowNode> ConvNode::
522
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
523
    return std::unique_ptr<data_flow::DataFlowNode>(new ConvNode(
×
524
        element_id,
×
525
        this->debug_info(),
×
526
        vertex,
×
527
        parent,
×
528
        shape_,
×
529
        kernel_shape_,
×
530
        strides_,
×
531
        pads_,
×
532
        dilations_,
×
533
        output_channels_,
×
534
        group_
×
535
    ));
×
536
}
×
537

538
std::string ConvNode::toStr() const {
×
539
    std::string result = "Conv(shape=[";
×
540
    for (size_t i = 0; i < shape_.size(); ++i) {
×
541
        if (i > 0) result += ", ";
×
542
        result += shape_[i]->__str__();
×
543
    }
×
544
    result += "], kernel_shape=[";
×
545
    for (size_t i = 0; i < kernel_shape_.size(); ++i) {
×
546
        if (i > 0) result += ", ";
×
547
        result += kernel_shape_[i]->__str__();
×
548
    }
×
549
    result += "], strides=[";
×
550
    for (size_t i = 0; i < strides_.size(); ++i) {
×
551
        if (i > 0) result += ", ";
×
552
        result += strides_[i]->__str__();
×
553
    }
×
554
    result += "], output_channels=" + output_channels_->__str__();
×
555
    result += ", group=" + group_->__str__() + ")";
×
556
    return result;
×
557
}
×
558

559
/**
 * Serializes a ConvNode to JSON.
 *
 * The result contains the library node code, one array of serialized
 * expressions for each of "shape", "kernel_shape", "strides", "pads" and
 * "dilations", and the serialized "output_channels" and "group" expressions.
 * The given library node is cast to ConvNode without a runtime check.
 */
nlohmann::json ConvNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const ConvNode& conv = static_cast<const ConvNode&>(library_node);
    nlohmann::json result;

    result["code"] = conv.code().value();

    serializer::JSONSerializer expr_writer;

    // Stores the expressions as a JSON array under the given key.
    const auto write_list = [&result, &expr_writer](const std::string& key, const auto& exprs) {
        result[key] = nlohmann::json::array();
        for (auto& expr : exprs) {
            result[key].push_back(expr_writer.expression(expr));
        }
    };

    write_list("shape", conv.shape());
    write_list("kernel_shape", conv.kernel_shape());
    write_list("strides", conv.strides());
    write_list("pads", conv.pads());
    write_list("dilations", conv.dilations());

    result["output_channels"] = expr_writer.expression(conv.output_channels());
    result["group"] = expr_writer.expression(conv.group());

    return result;
}
×
597

598
/**
 * Rebuilds a ConvNode from its JSON form and adds it to the given block.
 *
 * "element_id", "code", "debug_info" and "kernel_shape" are asserted to be
 * present. "shape", "strides", "pads" and "dilations" are optional and
 * default to empty lists; "output_channels" and "group" are optional and
 * default to one.
 *
 * @return Reference to the library node created through the builder.
 */
data_flow::LibraryNode& ConvNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("kernel_shape"));

    // Parses every string entry of a JSON array into a symbolic expression.
    const auto parse_all = [](const nlohmann::json& items) {
        std::vector<symbolic::Expression> parsed;
        for (const auto& item : items) {
            parsed.push_back(symbolic::parse(item.get<std::string>()));
        }
        return parsed;
    };

    // Same as parse_all, but yields an empty list when the key is absent.
    const auto parse_optional = [&j, &parse_all](const std::string& key) {
        std::vector<symbolic::Expression> parsed;
        if (j.contains(key)) {
            parsed = parse_all(j[key]);
        }
        return parsed;
    };

    // Parses a single expression, falling back to one when the key is absent.
    const auto parse_or_one = [&j](const std::string& key) {
        symbolic::Expression value = symbolic::one();
        if (j.contains(key)) {
            value = symbolic::parse(j[key].get<std::string>());
        }
        return value;
    };

    std::vector<symbolic::Expression> shape = parse_optional("shape");
    std::vector<symbolic::Expression> kernel_shape = parse_all(j["kernel_shape"]);
    std::vector<symbolic::Expression> strides = parse_optional("strides");
    std::vector<symbolic::Expression> pads = parse_optional("pads");
    std::vector<symbolic::Expression> dilations = parse_optional("dilations");

    symbolic::Expression output_channels = parse_or_one("output_channels");
    symbolic::Expression group = parse_or_one("group");

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    return builder.add_library_node<
        ConvNode>(parent, debug_info, shape, kernel_shape, strides, pads, dilations, output_channels, group);
}
×
655

656
} // namespace tensor
657
} // namespace math
658
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc