• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27981272983

22 Jun 2026 08:18PM UTC coverage: 61.754% (-0.03%) from 61.782%
27981272983

Pull #781

github

web-flow
Merge bddaa3724 into fe87d162b
Pull Request #781: Extend Segformer benchmarks setup

987 of 1432 new or added lines in 62 files covered. (68.92%)

9 existing lines in 7 files now uncovered.

38121 of 61730 relevant lines covered (61.75%)

993.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.95
/sdfg/src/data_flow/library_nodes/math/tensor/einsum_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/einsum_node.h"
2

3
#include <cstddef>
4
#include <memory>
5
#include <nlohmann/json_fwd.hpp>
6
#include <sstream>
7
#include <string>
8
#include <unordered_map>
9
#include <unordered_set>
10
#include <vector>
11

12
#include "sdfg/analysis/analysis.h"
13
#include "sdfg/builder/structured_sdfg_builder.h"
14
#include "sdfg/data_flow/access_node.h"
15
#include "sdfg/data_flow/data_flow_graph.h"
16
#include "sdfg/data_flow/data_flow_node.h"
17
#include "sdfg/data_flow/library_node.h"
18
#include "sdfg/data_flow/library_nodes/math/math_node.h"
19
#include "sdfg/data_flow/memlet.h"
20
#include "sdfg/data_flow/tasklet.h"
21
#include "sdfg/element.h"
22
#include "sdfg/exceptions.h"
23
#include "sdfg/function.h"
24
#include "sdfg/graph/graph.h"
25
#include "sdfg/serializer/json_serializer.h"
26
#include "sdfg/structured_control_flow/block.h"
27
#include "sdfg/structured_control_flow/control_flow_node.h"
28
#include "sdfg/structured_control_flow/map.h"
29
#include "sdfg/structured_control_flow/sequence.h"
30
#include "sdfg/structured_control_flow/structured_loop.h"
31
#include "sdfg/symbolic/symbolic.h"
32
#include "sdfg/types/scalar.h"
33
#include "sdfg/types/type.h"
34
#include "symengine/symbol.h"
35

36
namespace sdfg {
37
namespace math {
38
namespace tensor {
39

40
EinsumNode::EinsumNode(
41
    size_t element_id,
42
    const DebugInfo& debug_info,
43
    const graph::Vertex vertex,
44
    data_flow::DataFlowGraph& parent,
45
    const std::vector<std::string>& inputs,
46
    const std::vector<EinsumDimension>& dims,
47
    const data_flow::Subset& out_indices,
48
    const std::vector<data_flow::Subset>& in_indices,
49
    bool rename_indvars
50
)
51
    : math::MathNode(
54✔
52
          element_id,
54✔
53
          debug_info,
54✔
54
          vertex,
54✔
55
          parent,
54✔
56
          LibraryNodeType_Einsum,
54✔
57
          {"__einsum_out"},
54✔
58
          inputs,
54✔
59
          data_flow::ImplementationType_NONE
54✔
60
      ),
54✔
61
      dims_(dims), out_indices_(out_indices), in_indices_(in_indices) {
54✔
62
    // Check list sizes
63
    if (inputs.size() != in_indices.size()) {
54✔
64
        throw InvalidSDFGException("EinsumNode: Number of input containers != number of input indices");
×
65
    }
×
66

67
    // Rename indvars to internal symbols (only for fresh construction, not clone/deserialize)
68
    if (rename_indvars) {
54✔
69
        // Build mapping from original indvars to internal symbols
70
        // Format: _einsum_node_{element_id}_{original_indvar_name}
71
        std::string prefix = "_einsum_node_" + std::to_string(element_id) + "_";
40✔
72
        std::vector<std::pair<symbolic::Symbol, symbolic::Symbol>> indvar_renames;
40✔
73
        for (const auto& dim : this->dims_) {
40✔
74
            auto old_indvar = dim.indvar;
22✔
75
            auto old_name = SymEngine::rcp_static_cast<const SymEngine::Symbol>(old_indvar)->get_name();
22✔
76
            auto new_indvar = symbolic::symbol(prefix + old_name);
22✔
77
            indvar_renames.push_back({old_indvar, new_indvar});
22✔
78
        }
22✔
79

80
        // Apply all substitutions
81
        for (size_t idx = 0; idx < indvar_renames.size(); idx++) {
62✔
82
            auto old_indvar = indvar_renames[idx].first;
22✔
83
            auto new_indvar = indvar_renames[idx].second;
22✔
84

85
            // Replace in all dims' init, bound, and indvar
86
            for (auto& d : this->dims_) {
56✔
87
                if (symbolic::eq(d.indvar, old_indvar)) {
56✔
88
                    d.indvar = new_indvar;
22✔
89
                }
22✔
90
                d.init = symbolic::subs(d.init, old_indvar, new_indvar);
56✔
91
                d.bound = symbolic::subs(d.bound, old_indvar, new_indvar);
56✔
92
            }
56✔
93

94
            // Replace in out_indices
95
            for (size_t i = 0; i < this->out_indices_.size(); i++) {
57✔
96
                this->out_indices_[i] = symbolic::subs(this->out_indices_[i], old_indvar, new_indvar);
35✔
97
            }
35✔
98

99
            // Replace in in_indices
100
            for (size_t i = 0; i < this->in_indices_.size(); i++) {
62✔
101
                for (size_t j = 0; j < this->in_indices_[i].size(); j++) {
111✔
102
                    this->in_indices_[i][j] = symbolic::subs(this->in_indices_[i][j], old_indvar, new_indvar);
71✔
103
                }
71✔
104
            }
40✔
105
        }
22✔
106
    }
40✔
107

108
    // Append output at the end
109
    this->inputs_.push_back("__einsum_out");
54✔
110
    this->in_indices_.push_back(this->out_indices_);
54✔
111
}
54✔
112

113
const std::vector<EinsumDimension>& EinsumNode::dims() const { return this->dims_; }
165✔
114

115
const EinsumDimension& EinsumNode::dim(size_t index) const { return this->dims_.at(index); }
7✔
116

117
const symbolic::Symbol& EinsumNode::indvar(size_t index) const { return this->dims_.at(index).indvar; }
127✔
118

119
const symbolic::Expression& EinsumNode::init(size_t index) const { return this->dims_.at(index).init; }
40✔
120

121
const symbolic::Expression& EinsumNode::bound(size_t index) const { return this->dims_.at(index).bound; }
58✔
122

123
const data_flow::Subset& EinsumNode::out_indices() const { return this->out_indices_; }
137✔
124

125
const symbolic::Expression& EinsumNode::out_index(size_t index) const { return this->out_indices_.at(index); }
37✔
126

127
const std::vector<data_flow::Subset>& EinsumNode::in_indices() const { return this->in_indices_; }
77✔
128

129
const data_flow::Subset& EinsumNode::in_indices(size_t index) const { return this->in_indices_.at(index); }
96✔
130

131
const symbolic::Expression& EinsumNode::in_index(size_t index1, size_t index2) const {
82✔
132
    return this->in_indices_.at(index1).at(index2);
82✔
133
}
82✔
134

135
/// Symbols owned by this node: the induction variable of every dimension.
symbolic::SymbolSet EinsumNode::internal_symbols() const {
    symbolic::SymbolSet indvars;
    for (size_t k = 0; k < dims_.size(); ++k) {
        indvars.insert(dims_[k].indvar);
    }
    return indvars;
}
4✔
142

143
bool EinsumNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
6✔
144
    // Get data flow graph and block
145
    auto& dfg = this->get_parent();
6✔
146
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
6✔
147
    if (!block) {
6✔
148
        return false;
×
149
    }
×
150

151
    // Get parent sequence
152
    auto* sequence = dynamic_cast<structured_control_flow::Sequence*>(block->get_parent());
6✔
153
    if (!sequence) {
6✔
154
        return false;
×
155
    }
×
156

157
    // Create block after this block
158
    auto& block_after = builder.add_block_after(*sequence, *block, {}, block->debug_info());
6✔
159

160
    // Collect and transfer nodes after the EinsumNode
161
    bool before = true;
6✔
162
    std::unordered_map<data_flow::DataFlowNode*, data_flow::DataFlowNode*> nodes_after;
6✔
163
    for (auto* node : dfg.topological_sort()) {
34✔
164
        if (before) {
34✔
165
            if (node == this) {
23✔
166
                before = false;
6✔
167
            }
6✔
168
            continue;
23✔
169
        }
23✔
170
        data_flow::DataFlowNode* node_after = nullptr;
11✔
171
        if (auto* constant_node = dynamic_cast<data_flow::ConstantNode*>(node)) {
11✔
172
            node_after =
×
173
                &builder
×
174
                     .add_constant(block_after, constant_node->data(), constant_node->type(), constant_node->debug_info());
×
175
        } else if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(node)) {
11✔
176
            if (dfg.out_degree(*access_node) == 0 && dfg.in_degree(*access_node) == 1 &&
9✔
177
                &(*dfg.in_edges(*access_node).begin()).src() == this) {
9✔
178
                continue;
5✔
179
            }
5✔
180
            node_after = &builder.add_access(block_after, access_node->data(), access_node->debug_info());
4✔
181
        } else if (auto* code_node = dynamic_cast<data_flow::CodeNode*>(node)) {
4✔
182
            node_after = &builder.copy_node(block_after, *code_node);
2✔
183
        } else {
2✔
184
            return false;
×
185
        }
×
186
        nodes_after.insert({node, node_after});
6✔
187
        if (dynamic_cast<data_flow::Tasklet*>(node) || dynamic_cast<data_flow::LibraryNode*>(node)) {
6✔
188
            for (auto& iedge : dfg.in_edges(*node)) {
3✔
189
                if (!nodes_after.contains(&iedge.src())) {
3✔
190
                    if (auto* constant_node = dynamic_cast<data_flow::ConstantNode*>(&iedge.src())) {
×
191
                        nodes_after.insert(
×
192
                            {constant_node,
×
193
                             &builder.add_constant(
×
194
                                 block_after, constant_node->data(), constant_node->type(), constant_node->debug_info()
×
195
                             )}
×
196
                        );
×
197
                    } else if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(&iedge.src())) {
×
198
                        nodes_after.insert(
×
199
                            {access_node,
×
200
                             &builder.add_access(block_after, access_node->data(), access_node->debug_info())}
×
201
                        );
×
202
                    } else {
×
203
                        return false;
×
204
                    }
×
205
                }
×
206
            }
3✔
207
        }
2✔
208
    }
6✔
209

210
    // Transfer memlets after the EinsumNode
211
    for (auto& edge : dfg.edges()) {
28✔
212
        if (!nodes_after.contains(&edge.src()) || !nodes_after.contains(&edge.dst())) {
28✔
213
            continue;
23✔
214
        }
23✔
215
        builder.add_memlet(
5✔
216
            block_after,
5✔
217
            *nodes_after[&edge.src()],
5✔
218
            edge.src_conn(),
5✔
219
            *nodes_after[&edge.dst()],
5✔
220
            edge.dst_conn(),
5✔
221
            edge.subset(),
5✔
222
            edge.base_type(),
5✔
223
            edge.debug_info()
5✔
224
        );
5✔
225
    }
5✔
226

227
    // Delete transferred data flow in the original block
228
    std::unordered_set<data_flow::Memlet*> edges_for_removal;
6✔
229
    for (auto& edge : dfg.edges()) {
28✔
230
        if (nodes_after.contains(&edge.src()) && nodes_after.contains(&edge.dst())) {
28✔
231
            edges_for_removal.insert(&edge);
5✔
232
        }
5✔
233
    }
28✔
234
    for (auto* edge : edges_for_removal) {
6✔
235
        builder.remove_memlet(*block, *edge);
5✔
236
    }
5✔
237
    std::unordered_set<data_flow::DataFlowNode*> nodes_for_removal;
6✔
238
    for (auto& node : dfg.nodes()) {
34✔
239
        if (dfg.in_degree(node) == 0 && dfg.out_degree(node) == 0) {
34✔
240
            nodes_for_removal.insert(&node);
5✔
241
        }
5✔
242
    }
34✔
243
    for (auto* node : nodes_for_removal) {
6✔
244
        builder.remove_node(*block, *node);
5✔
245
    }
5✔
246

247
    // Add containers for loop induction variables (symbols already renamed in constructor)
248
    for (size_t i = 0; i < this->dims().size(); i++) {
17✔
249
        auto indvar = this->indvar(i);
11✔
250
        auto indvar_name = SymEngine::rcp_static_cast<const SymEngine::Symbol>(indvar)->get_name();
11✔
251
        if (builder.subject().exists(indvar_name)) {
11✔
252
            continue;
5✔
253
        }
5✔
254
        builder.add_container(indvar_name, types::Scalar(types::PrimitiveType::Int64));
6✔
255
    }
6✔
256

257
    // Add loops
258
    structured_control_flow::Sequence* current_sequence = nullptr;
6✔
259
    bool map = true;
6✔
260
    for (size_t i = 0; i < this->dims().size(); i++) {
17✔
261
        if (map) {
11✔
262
            if (i >= this->out_indices().size() || !symbolic::uses(this->out_index(i), this->indvar(i))) {
11✔
263
                map = false;
6✔
264
            } else {
6✔
265
                for (size_t j = 0; j < i; j++) {
7✔
266
                    if (symbolic::uses(this->init(i), this->indvar(j)) ||
2✔
267
                        symbolic::uses(this->bound(i), this->indvar(j))) {
2✔
268
                        map = false;
×
269
                        break;
×
270
                    }
×
271
                }
2✔
272
            }
5✔
273
        }
11✔
274
        auto indvar = this->indvar(i);
11✔
275
        auto condition = symbolic::Lt(indvar, this->bound(i));
11✔
276
        auto init = this->init(i);
11✔
277
        auto update = symbolic::add(indvar, symbolic::one());
11✔
278
        if (current_sequence) {
11✔
279
            structured_control_flow::StructuredLoop* loop;
5✔
280
            if (map) {
5✔
281
                loop = &builder.add_map(
2✔
282
                    *current_sequence,
2✔
283
                    indvar,
2✔
284
                    condition,
2✔
285
                    init,
2✔
286
                    update,
2✔
287
                    ScheduleType_Sequential::create(),
2✔
288
                    {},
2✔
289
                    this->debug_info()
2✔
290
                );
2✔
291
            } else {
3✔
292
                loop = &builder.add_for(*current_sequence, indvar, condition, init, update, {}, this->debug_info());
3✔
293
            }
3✔
294
            current_sequence = &loop->root();
5✔
295
        } else {
6✔
296
            structured_control_flow::StructuredLoop* loop;
6✔
297
            if (map) {
6✔
298
                loop = &builder.add_map_after(
3✔
299
                    *sequence,
3✔
300
                    *block,
3✔
301
                    indvar,
3✔
302
                    condition,
3✔
303
                    init,
3✔
304
                    update,
3✔
305
                    ScheduleType_Sequential::create(),
3✔
306
                    {},
3✔
307
                    this->debug_info()
3✔
308
                );
3✔
309
            } else {
3✔
310
                loop =
3✔
311
                    &builder.add_for_after(*sequence, *block, indvar, condition, init, update, {}, this->debug_info());
3✔
312
            }
3✔
313
            current_sequence = &loop->root();
6✔
314
        }
6✔
315
    }
11✔
316

317
    // Add new block
318
    structured_control_flow::Block* new_block;
6✔
319
    if (current_sequence) {
6✔
320
        new_block = &builder.add_block(*current_sequence);
6✔
321
    } else {
6✔
322
        new_block = &builder.add_block_after(*sequence, *block, {}, this->debug_info());
×
323
    }
×
324

325
    // Transfer the access nodes of the EinsumNode
326
    std::unordered_map<std::string, data_flow::AccessNode*> new_in_accesses;
6✔
327
    std::unordered_map<std::string, const types::IType&> in_types;
6✔
328
    for (auto& iedge : dfg.in_edges(*this)) {
15✔
329
        in_types.insert({iedge.dst_conn(), iedge.base_type()});
15✔
330
        if (auto* constant_node = dynamic_cast<data_flow::ConstantNode*>(&iedge.src())) {
15✔
331
            new_in_accesses.insert(
×
332
                {iedge.dst_conn(),
×
333
                 &builder
×
334
                      .add_constant(*new_block, constant_node->data(), constant_node->type(), constant_node->debug_info())
×
335
                }
×
336
            );
×
337
        } else if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(&iedge.src())) {
15✔
338
            data_flow::AccessNode* new_access_node = nullptr;
15✔
339
            for (auto [conn, other_access_node] : new_in_accesses) {
15✔
340
                if (access_node->data() == other_access_node->data()) {
13✔
341
                    new_access_node = other_access_node;
×
342
                    break;
×
343
                }
×
344
            }
13✔
345
            if (!new_access_node) {
15✔
346
                new_access_node = &builder.add_access(*new_block, access_node->data(), access_node->debug_info());
15✔
347
            }
15✔
348
            new_in_accesses.insert({iedge.dst_conn(), new_access_node});
15✔
349
        } else {
15✔
350
            return false;
×
351
        }
×
352
    }
15✔
353
    data_flow::AccessNode* new_out_access;
6✔
354
    const types::IType* out_type;
6✔
355
    {
6✔
356
        auto& oedge = *dfg.out_edges(*this).begin();
6✔
357
        out_type = &oedge.base_type();
6✔
358
        if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(&oedge.dst())) {
6✔
359
            new_out_access = &builder.add_access(*new_block, access_node->data(), access_node->debug_info());
6✔
360
        } else {
6✔
361
            return false;
×
362
        }
×
363
    }
6✔
364

365
    // Add computations to the block
366
    if (this->inputs().size() == 1) {
6✔
367
        auto& tasklet =
×
368
            builder.add_tasklet(*new_block, data_flow::TaskletCode::assign, {"_out"}, {"_in0"}, this->debug_info());
×
369
        builder.add_memlet(
×
370
            *new_block,
×
371
            *new_in_accesses.at(this->input(0)),
×
372
            "void",
×
373
            tasklet,
×
374
            "_in0",
×
375
            this->in_indices(0),
×
376
            in_types.at(this->input(0)),
×
377
            this->debug_info()
×
378
        );
×
379
        builder.add_memlet(
×
380
            *new_block, tasklet, "_out", *new_out_access, "void", this->out_indices(), *out_type, this->debug_info()
×
381
        );
×
382
    } else if (this->inputs().size() == 2) {
6✔
383
        auto& tasklet =
4✔
384
            builder
4✔
385
                .add_tasklet(*new_block, data_flow::TaskletCode::fp_add, {"_out"}, {"_in0", "_in1"}, this->debug_info());
4✔
386
        builder.add_memlet(
4✔
387
            *new_block,
4✔
388
            *new_in_accesses.at(this->input(0)),
4✔
389
            "void",
4✔
390
            tasklet,
4✔
391
            "_in0",
4✔
392
            this->in_indices(0),
4✔
393
            in_types.at(this->input(0)),
4✔
394
            this->debug_info()
4✔
395
        );
4✔
396
        builder.add_memlet(
4✔
397
            *new_block,
4✔
398
            *new_in_accesses.at(this->input(1)),
4✔
399
            "void",
4✔
400
            tasklet,
4✔
401
            "_in1",
4✔
402
            this->in_indices(1),
4✔
403
            in_types.at(this->input(1)),
4✔
404
            this->debug_info()
4✔
405
        );
4✔
406
        builder.add_memlet(
4✔
407
            *new_block, tasklet, "_out", *new_out_access, "void", this->out_indices(), *out_type, this->debug_info()
4✔
408
        );
4✔
409
    } else {
4✔
410
        // Build a mapping from original connector names to internal names and indices
411
        std::unordered_map<std::string, data_flow::Subset> in_indices;
2✔
412
        std::unordered_map<std::string, std::string> conn_to_internal;
2✔
413
        for (size_t i = 0; i < this->inputs().size(); i++) {
9✔
414
            in_indices.insert({this->input(i), this->in_indices(i)});
7✔
415
            conn_to_internal.insert({this->input(i), "_in" + std::to_string(i)});
7✔
416
        }
7✔
417
        long long inp;
2✔
418
        for (inp = 0; inp < (long long) this->inputs().size() - 3; inp++) {
3✔
419
            auto tmp = builder.find_new_name();
1✔
420
            auto& tmp_type = builder.add_container(tmp, types::Scalar(in_types.at(this->input(inp)).primitive_type()));
1✔
421
            auto& tmp_access = builder.add_access(*new_block, tmp);
1✔
422
            std::string int_conn0 = conn_to_internal.at(this->input(inp));
1✔
423
            std::string int_conn1 = conn_to_internal.at(this->input(inp + 1));
1✔
424
            auto& tasklet = builder.add_tasklet(
1✔
425
                *new_block, data_flow::TaskletCode::fp_mul, {"_out"}, {int_conn0, int_conn1}, this->debug_info()
1✔
426
            );
1✔
427
            builder.add_memlet(
1✔
428
                *new_block,
1✔
429
                *new_in_accesses.at(this->input(inp)),
1✔
430
                "void",
1✔
431
                tasklet,
1✔
432
                int_conn0,
1✔
433
                in_indices.at(this->input(inp)),
1✔
434
                in_types.at(this->input(inp)),
1✔
435
                this->debug_info()
1✔
436
            );
1✔
437
            builder.add_memlet(
1✔
438
                *new_block,
1✔
439
                *new_in_accesses.at(this->input(inp + 1)),
1✔
440
                "void",
1✔
441
                tasklet,
1✔
442
                int_conn1,
1✔
443
                in_indices.at(this->input(inp + 1)),
1✔
444
                in_types.at(this->input(inp + 1)),
1✔
445
                this->debug_info()
1✔
446
            );
1✔
447
            builder.add_memlet(*new_block, tasklet, "_out", tmp_access, "void", {}, tmp_type, this->debug_info());
1✔
448
            new_in_accesses[this->input(inp + 1)] = &tmp_access;
1✔
449
            in_indices[this->input(inp + 1)].clear();
1✔
450
            in_types.erase(this->input(inp + 1));
1✔
451
            in_types.insert({this->input(inp + 1), tmp_type});
1✔
452
        }
1✔
453
        std::string int_conn0 = conn_to_internal.at(this->input(inp));
2✔
454
        std::string int_conn1 = conn_to_internal.at(this->input(inp + 1));
2✔
455
        std::string int_conn2 = conn_to_internal.at(this->input(inp + 2));
2✔
456
        auto& tasklet = builder.add_tasklet(
2✔
457
            *new_block, data_flow::TaskletCode::fp_fma, {"_out"}, {int_conn0, int_conn1, int_conn2}, this->debug_info()
2✔
458
        );
2✔
459
        builder.add_memlet(
2✔
460
            *new_block,
2✔
461
            *new_in_accesses.at(this->input(inp)),
2✔
462
            "void",
2✔
463
            tasklet,
2✔
464
            int_conn0,
2✔
465
            in_indices.at(this->input(inp)),
2✔
466
            in_types.at(this->input(inp)),
2✔
467
            this->debug_info()
2✔
468
        );
2✔
469
        builder.add_memlet(
2✔
470
            *new_block,
2✔
471
            *new_in_accesses.at(this->input(inp + 1)),
2✔
472
            "void",
2✔
473
            tasklet,
2✔
474
            int_conn1,
2✔
475
            in_indices.at(this->input(inp + 1)),
2✔
476
            in_types.at(this->input(inp + 1)),
2✔
477
            this->debug_info()
2✔
478
        );
2✔
479
        builder.add_memlet(
2✔
480
            *new_block,
2✔
481
            *new_in_accesses.at(this->input(inp + 2)),
2✔
482
            "void",
2✔
483
            tasklet,
2✔
484
            int_conn2,
2✔
485
            in_indices.at(this->input(inp + 2)),
2✔
486
            in_types.at(this->input(inp + 2)),
2✔
487
            this->debug_info()
2✔
488
        );
2✔
489
        builder.add_memlet(
2✔
490
            *new_block, tasklet, "_out", *new_out_access, "void", this->out_indices(), *out_type, this->debug_info()
2✔
491
        );
2✔
492
    }
2✔
493

494
    // Remove EinsumNode and its access nodes and memlets
495
    std::unordered_set<data_flow::AccessNode*> old_accesses;
6✔
496
    while (dfg.in_edges(*this).begin() != dfg.in_edges(*this).end()) {
21✔
497
        auto& iedge = *dfg.in_edges(*this).begin();
15✔
498
        old_accesses.insert(dynamic_cast<data_flow::AccessNode*>(&iedge.src()));
15✔
499
        builder.remove_memlet(*block, iedge);
15✔
500
    }
15✔
501
    while (dfg.out_edges(*this).begin() != dfg.out_edges(*this).end()) {
12✔
502
        auto& oedge = *dfg.out_edges(*this).begin();
6✔
503
        old_accesses.insert(dynamic_cast<data_flow::AccessNode*>(&oedge.dst()));
6✔
504
        builder.remove_memlet(*block, oedge);
6✔
505
    }
6✔
506
    for (auto* old_access : old_accesses) {
21✔
507
        if (dfg.in_degree(*old_access) == 0 && dfg.out_degree(*old_access) == 0) {
21✔
508
            builder.remove_node(*block, *old_access);
20✔
509
        }
20✔
510
    }
21✔
511
    builder.remove_node(*block, *this);
6✔
512

513
    // Remove block before loops if empty
514
    size_t block_index = sequence->index(*block);
6✔
515
    if (dfg.nodes().size() == 0 && sequence->at(block_index).second.empty()) {
6✔
516
        builder.remove_child(*sequence, sequence->index(*block));
5✔
517
    }
5✔
518

519
    // Remove block after loops if empty
520
    if (block_after.dataflow().nodes().size() == 0) {
6✔
521
        builder.remove_child(*sequence, sequence->index(block_after));
5✔
522
    }
5✔
523

524
    return true;
6✔
525
}
6✔
526

527
/// External symbols referenced by this node, i.e. every symbol that is not one of the node's
/// internal induction variables.
///
/// Collected from the dimension ranges (init/bound) and from all index expressions. Index
/// expressions are usually built solely from the internal indvars (and then contribute nothing),
/// but the constructor accepts arbitrary expressions such as A[i + offset]; any external symbol
/// they contain must be reported as well.
symbolic::SymbolSet EinsumNode::symbols() const {
    symbolic::SymbolSet result;
    const symbolic::SymbolSet internal = this->internal_symbols();

    auto collect_external = [&](const symbolic::Expression& expr) {
        for (auto& symbol : symbolic::atoms(expr)) {
            if (!internal.count(symbol)) {
                result.insert(symbol);
            }
        }
    };

    for (auto& dim : this->dims()) {
        collect_external(dim.init);
        collect_external(dim.bound);
    }

    for (auto& index : this->out_indices()) {
        collect_external(index);
    }
    for (auto& subset : this->in_indices()) {
        for (auto& index : subset) {
            collect_external(index);
        }
    }

    return result;
}
4✔
549

550
/// Substitutes `old_expression` by `new_expression` in all expressions of this node.
///
/// Internal induction variables are private to the node: if `old_expression` is one of them the
/// call is a no-op. Substitution covers the dimension ranges and all index expressions. Indices
/// normally contain only internal indvars (then this is a no-op for them), but they may reference
/// external symbols too (e.g. A[i + offset]), which must be renamed consistently. The trailing
/// in_indices_ entry mirrors out_indices_; substituting in both keeps them equal.
void EinsumNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Skip if old_expression is an internal symbol (indvar)
    for (auto& dim : this->dims()) {
        if (symbolic::eq(dim.indvar, old_expression)) {
            return; // Internal symbol - do not replace
        }
    }

    for (auto& dim : this->dims_) {
        dim.init = symbolic::subs(dim.init, old_expression, new_expression);
        dim.bound = symbolic::subs(dim.bound, old_expression, new_expression);
    }

    for (auto& index : this->out_indices_) {
        index = symbolic::subs(index, old_expression, new_expression);
    }
    for (auto& subset : this->in_indices_) {
        for (auto& index : subset) {
            index = symbolic::subs(index, old_expression, new_expression);
        }
    }
}
×
566

NEW
567
/// Applies a batch of substitutions to all expressions of this node.
///
/// Entries whose key is one of the node's internal induction variables are ignored, because those
/// symbols are private to the node. The remaining substitutions are applied to the dimension
/// ranges and to all index expressions. Indices normally contain only internal indvars (then this
/// is a no-op for them), but may also reference external symbols (e.g. A[i + offset]). The
/// trailing in_indices_ entry mirrors out_indices_ and receives the same substitution.
void EinsumNode::replace(const symbolic::ExpressionMapping& replacements) {
    // Filter out replacements whose key is an internal symbol (indvar)
    symbolic::ExpressionMapping filtered;
    for (auto& pair : replacements) {
        bool is_internal = false;
        for (auto& dim : this->dims_) {
            if (symbolic::eq(dim.indvar, pair.first)) {
                is_internal = true;
                break;
            }
        }
        if (!is_internal) {
            filtered[pair.first] = pair.second;
        }
    }

    if (filtered.empty()) {
        return;
    }

    for (auto& dim : this->dims_) {
        dim.init = symbolic::subs(dim.init, filtered);
        dim.bound = symbolic::subs(dim.bound, filtered);
    }

    for (auto& index : this->out_indices_) {
        index = symbolic::subs(index, filtered);
    }
    for (auto& subset : this->in_indices_) {
        for (auto& index : subset) {
            index = symbolic::subs(index, filtered);
        }
    }
}
×
595

596
std::string EinsumNode::toStr() const {
4✔
597
    std::stringstream stream;
4✔
598

599
    stream << this->output(0);
4✔
600
    for (auto& index : this->out_indices()) {
5✔
601
        stream << "[" << index->__str__() << "]";
5✔
602
    }
5✔
603
    stream << " = ";
4✔
604
    size_t num_inputs = this->inputs().size();
4✔
605
    if (num_inputs > 1) {
4✔
606
        for (size_t i = 0; i < num_inputs - 1; i++) {
10✔
607
            if (i > 0) {
6✔
608
                stream << " * ";
2✔
609
            }
2✔
610
            stream << this->input(i);
6✔
611
            for (auto& index : this->in_indices(i)) {
11✔
612
                stream << "[" << index->__str__() << "]";
11✔
613
            }
11✔
614
        }
6✔
615
        stream << " + ";
4✔
616
    }
4✔
617
    stream << this->input(num_inputs - 1);
4✔
618
    for (auto& index : this->in_indices(num_inputs - 1)) {
5✔
619
        stream << "[" << index->__str__() << "]";
5✔
620
    }
5✔
621

622
    for (auto& dim : this->dims()) {
9✔
623
        stream << " for " << dim.indvar->__str__() << " = " << dim.init->__str__() << " : " << dim.bound->__str__();
9✔
624
    }
9✔
625

626
    return stream.str();
4✔
627
}
4✔
628

629
/// Symbolic operation count of the einsum.
///
/// The iteration volume is the product of every dimension's trip count (bound - init). When a
/// dimension's range depends on an earlier indvar j, j is replaced by half of j's own trip count
/// (an approximation for triangular iteration spaces). The volume is multiplied by the number of
/// inputs excluding the trailing "__einsum_out" accumulator.
symbolic::Expression EinsumNode::flop() const {
    symbolic::SymbolMap trip_counts;
    symbolic::Expression volume = symbolic::one();

    const size_t rank = this->dims().size();
    for (size_t i = 0; i < rank; ++i) {
        symbolic::Expression extent = symbolic::sub(this->bound(i), this->init(i));
        for (size_t j = 0; j < i; ++j) {
            const auto& outer = this->indvar(j);
            for (auto& sym : symbolic::atoms(extent)) {
                if (!symbolic::eq(sym, outer)) {
                    continue;
                }
                extent = symbolic::subs(extent, sym, symbolic::div(trip_counts.at(sym), symbolic::integer(2)));
            }
        }
        trip_counts.insert({this->indvar(i), extent});
        volume = symbolic::mul(volume, extent);
    }

    const auto ops_per_iteration = symbolic::integer(this->inputs().size() - 1);
    return symbolic::mul(volume, ops_per_iteration);
}
4✔
649

650
std::unique_ptr<data_flow::DataFlowNode> EinsumNode::
651
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
652
    return std::make_unique<EinsumNode>(
×
653
        element_id,
×
654
        this->debug_info(),
×
655
        vertex,
×
656
        parent,
×
657
        std::vector<std::string>(this->inputs().begin(), this->inputs().end() - 1),
×
658
        this->dims(),
×
659
        this->out_indices(),
×
660
        std::vector<data_flow::Subset>(this->in_indices().begin(), this->in_indices().end() - 1),
×
661
        false // skip renaming - already internal symbols
×
662
    );
×
663
}
×
664

665
void EinsumNode::validate(const Function& function) const {
22✔
666
    // Check inputs
667
    size_t inputs_size = this->inputs().size();
22✔
668
    if (inputs_size == 0) {
22✔
669
        throw InvalidSDFGException("EinsumNode: Inputs of EinsumNode must not be empty");
×
670
    }
×
671
    for (size_t i = 0; i < inputs_size - 1; i++) {
60✔
672
        if (this->input(i) == "__einsum_out") {
38✔
673
            throw InvalidSDFGException("EinsumNode: Input '__einsum_out' at wrong position");
×
674
        }
×
675
    }
38✔
676
    if (this->input(inputs_size - 1) != "__einsum_out") {
22✔
677
        throw InvalidSDFGException("EinsumNode: Last input of EinsumNode must be '__einsum_out'");
×
678
    }
×
679

680
    // Check last in indices
681
    if (this->out_indices().size() != this->in_indices(inputs_size - 1).size()) {
22✔
682
        throw InvalidSDFGException("EinsumNode: Out indices and last in indices have different sizes");
×
683
    }
×
684
    for (size_t i = 0; i < this->out_indices().size(); i++) {
40✔
685
        if (!symbolic::eq(this->out_index(i), this->in_index(inputs_size - 1, i))) {
18✔
686
            throw InvalidSDFGException("EinsumNode: Out indices and last in indices do not match");
×
687
        }
×
688
    }
18✔
689

690
    // Check input containers
691
    auto& dfg = this->get_parent();
22✔
692
    auto& oedge = *dfg.out_edges(*this).begin();
22✔
693
    std::string out_container = dynamic_cast<const data_flow::AccessNode&>(oedge.dst()).data();
22✔
694
    for (auto& iedge : dfg.in_edges(*this)) {
60✔
695
        auto& src = dynamic_cast<const data_flow::AccessNode&>(iedge.src());
60✔
696
        if (src.data() != out_container && iedge.dst_conn() == "__einsum_out") {
60✔
697
            throw InvalidSDFGException("EinsumNode: Out container must occur as a summation in the inputs");
×
698
        }
×
699
    }
60✔
700

701
    // Check if dimensions index variables occur at least once as in/out indices
702
    for (size_t i = 0; i < this->dims().size(); i++) {
45✔
703
        bool unused = true;
23✔
704
        for (auto& index : this->out_indices()) {
30✔
705
            for (auto& symbol : symbolic::atoms(index)) {
30✔
706
                if (symbolic::eq(this->indvar(i), symbol)) {
30✔
707
                    unused = false;
13✔
708
                    break;
13✔
709
                }
13✔
710
            }
30✔
711
            if (!unused) {
30✔
712
                break;
13✔
713
            }
13✔
714
        }
30✔
715
        if (!unused) {
23✔
716
            continue;
13✔
717
        }
13✔
718
        for (auto& indices : this->in_indices()) {
12✔
719
            for (auto& index : indices) {
17✔
720
                for (auto& symbol : symbolic::atoms(index)) {
17✔
721
                    if (symbolic::eq(this->indvar(i), symbol)) {
17✔
722
                        unused = false;
10✔
723
                        break;
10✔
724
                    }
10✔
725
                }
17✔
726
                if (!unused) {
17✔
727
                    break;
10✔
728
                }
10✔
729
            }
17✔
730
            if (!unused) {
12✔
731
                break;
10✔
732
            }
10✔
733
        }
12✔
734
        if (unused) {
10✔
735
            throw InvalidSDFGException(
×
736
                "EinsumNode: Dimension indvar does not occur in the in/out indices: " + this->indvar(i)->__str__()
×
737
            );
×
738
        }
×
739
    }
10✔
740
}
22✔
741

742
/// Serializes an EinsumNode into JSON.
///
/// The resulting object contains "type", "code", "side_effect", "output",
/// "inputs", "dims" (objects with "indvar", "init" and "bound"), "out_indices"
/// and "in_indices" (one array per input). Symbolic expressions are rendered
/// with serializer::JSONSymbolicPrinter.
///
/// @param libnode Library node to serialize; must carry LibraryNodeType_Einsum.
/// @return JSON representation of the node.
/// @throws InvalidSDFGException if `libnode` is not an Einsum library node.
nlohmann::json EinsumSerializer::serialize(const data_flow::LibraryNode& libnode) {
    if (libnode.code() != LibraryNodeType_Einsum) {
        throw InvalidSDFGException("EinsumSerializer: Invalid library node type");
    }

    const auto& node = static_cast<const EinsumNode&>(libnode);
    serializer::JSONSymbolicPrinter printer;

    // Renders one index list as a JSON array of printed expressions.
    auto print_indices = [&printer](const auto& indices) {
        nlohmann::json printed = nlohmann::json::array();
        for (auto& index : indices) {
            printed.push_back(printer.apply(index));
        }
        return printed;
    };

    nlohmann::json j;
    j["type"] = "library_node";
    j["code"] = std::string(LibraryNodeType_Einsum.value());
    j["side_effect"] = node.side_effect();
    j["output"] = node.output(0);

    nlohmann::json inputs_json = nlohmann::json::array();
    for (auto& name : node.inputs()) {
        inputs_json.push_back(name);
    }
    j["inputs"] = inputs_json;

    nlohmann::json dims_json = nlohmann::json::array();
    for (auto& dim : node.dims()) {
        nlohmann::json dim_json;
        dim_json["indvar"] = printer.apply(dim.indvar);
        dim_json["init"] = printer.apply(dim.init);
        dim_json["bound"] = printer.apply(dim.bound);
        dims_json.push_back(dim_json);
    }
    j["dims"] = dims_json;

    j["out_indices"] = print_indices(node.out_indices());

    nlohmann::json in_indices_json = nlohmann::json::array();
    for (auto& indices : node.in_indices()) {
        in_indices_json.push_back(print_indices(indices));
    }
    j["in_indices"] = in_indices_json;

    return j;
}
1✔
787

788
data_flow::LibraryNode& EinsumSerializer::deserialize(
789
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
790
) {
1✔
791
    assert(j.contains("type"));
1✔
792
    assert(j["type"].is_string());
1✔
793
    assert(j.contains("code"));
1✔
794
    assert(j["code"].is_string());
1✔
795
    assert(j.contains("side_effect"));
1✔
796
    assert(j["side_effect"].is_boolean());
1✔
797
    assert(j.contains("output"));
1✔
798
    assert(j["output"].is_string());
1✔
799
    assert(j.contains("inputs"));
1✔
800
    assert(j["inputs"].is_array());
1✔
801
    assert(j.contains("dims"));
1✔
802
    assert(j["dims"].is_array());
1✔
803
    assert(j.contains("out_indices"));
1✔
804
    assert(j["out_indices"].is_array());
1✔
805
    assert(j.contains("in_indices"));
1✔
806
    assert(j["in_indices"].is_array());
1✔
807
    assert(j["inputs"].size() == j["in_indices"].size());
1✔
808

809
    auto type = j["type"].get<std::string>();
1✔
810
    if (type != "library_node") {
1✔
811
        throw InvalidSDFGException("EinsumSerializer: Invalid library node type");
×
812
    }
×
813

814
    auto code = j["code"].get<std::string>();
1✔
815
    if (code != LibraryNodeType_Einsum.value()) {
1✔
816
        throw InvalidSDFGException("EinsumSerializer: Invalid library node code");
×
817
    }
×
818

819
    auto side_effect = j["side_effect"].get<bool>();
1✔
820
    if (side_effect) {
1✔
821
        throw InvalidSDFGException("EinsumSerializer: EinsumNodes must be free of side effects");
×
822
    }
×
823

824
    auto output = j["output"].get<std::string>();
1✔
825
    if (output != "__einsum_out") {
1✔
826
        throw InvalidSDFGException("EinsumSerializer: Output of EinsumNode must be '__einsum_out'");
×
827
    }
×
828

829
    auto inputs = j["inputs"].get<std::vector<std::string>>();
1✔
830
    size_t inputs_size = inputs.size();
1✔
831
    if (inputs_size == 0) {
1✔
832
        throw InvalidSDFGException("EinsumSerializer: Inputs of EinsumNode must not be empty");
×
833
    }
×
834
    if (inputs[inputs_size - 1] != "__einsum_out") {
1✔
835
        throw InvalidSDFGException("EinsumSerializer: Last input of EinsumNode must be '__einsum_out'");
×
836
    }
×
837

838
    std::vector<EinsumDimension> dims;
1✔
839
    for (size_t i = 0; i < j["dims"].size(); i++) {
4✔
840
        auto& dimj = j["dims"][i];
3✔
841
        assert(dimj.is_object());
3✔
842
        assert(dimj.contains("indvar"));
3✔
843
        assert(dimj["indvar"].is_string());
3✔
844
        assert(dimj.contains("init"));
3✔
845
        assert(dimj["init"].is_string());
3✔
846
        assert(dimj.contains("bound"));
3✔
847
        assert(dimj["bound"].is_string());
3✔
848

849
        EinsumDimension dim;
3✔
850
        dim.indvar = symbolic::symbol(dimj["indvar"]);
3✔
851
        dim.init = symbolic::parse(dimj["init"]);
3✔
852
        dim.bound = symbolic::parse(dimj["bound"]);
3✔
853
        dims.push_back(dim);
3✔
854
    }
3✔
855

856
    data_flow::Subset out_indices;
1✔
857
    auto out_indices_str = j["out_indices"].get<std::vector<std::string>>();
1✔
858
    for (auto& index_str : out_indices_str) {
2✔
859
        out_indices.push_back(symbolic::parse(index_str));
2✔
860
    }
2✔
861

862
    std::vector<data_flow::Subset> in_indices;
1✔
863
    for (size_t i = 0; i < j["in_indices"].size(); i++) {
4✔
864
        assert(j["in_indices"][i].is_array());
3✔
865

866
        data_flow::Subset indices;
3✔
867
        auto indices_str = j["in_indices"][i].get<std::vector<std::string>>();
3✔
868
        for (auto& index_str : indices_str) {
6✔
869
            indices.push_back(symbolic::parse(index_str));
6✔
870
        }
6✔
871
        in_indices.push_back(indices);
3✔
872
    }
3✔
873
    if (out_indices.size() != in_indices[inputs_size - 1].size()) {
1✔
874
        throw InvalidSDFGException("EinsumSerializer: Out indices and last in indices have different sizes");
×
875
    }
×
876
    for (size_t i = 0; i < out_indices.size(); i++) {
3✔
877
        if (!symbolic::eq(out_indices[i], in_indices[inputs_size - 1][i])) {
2✔
878
            throw InvalidSDFGException("EinsumSerializer: Out indices and last in indices do not match");
×
879
        }
×
880
    }
2✔
881

882
    auto& einsum_node = builder.add_library_node<
1✔
883
        EinsumNode,
1✔
884
        const std::vector<std::string>&,
1✔
885
        const std::vector<EinsumDimension>&,
1✔
886
        const data_flow::Subset&,
1✔
887
        const std::vector<data_flow::Subset>&,
1✔
888
        bool>(
1✔
889
        parent,
1✔
890
        DebugInfo(),
1✔
891
        std::vector<std::string>(inputs.begin(), inputs.end() - 1),
1✔
892
        dims,
1✔
893
        out_indices,
1✔
894
        std::vector<data_flow::Subset>(in_indices.begin(), in_indices.end() - 1),
1✔
895
        false // skip renaming - already internal symbols from serialization
1✔
896
    );
1✔
897

898
    return einsum_node;
1✔
899
}
1✔
900

901
} // namespace tensor
902
} // namespace math
903
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc