• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23500309960

24 Mar 2026 04:23PM UTC coverage: 64.456% (+0.2%) from 64.295%
23500309960

Pull #605

github

web-flow
Merge 732825a5e into e56781552
Pull Request #605: Move einsum support

1303 of 1918 new or added lines in 14 files covered. (67.94%)

2 existing lines in 1 file now uncovered.

27952 of 43366 relevant lines covered (64.46%)

392.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.33
/opt/src/transformations/einsum_extend.cpp
1
#include "sdfg/transformations/einsum_extend.h"
2

3
#include <cstddef>
4
#include <nlohmann/json_fwd.hpp>
5
#include <string>
6
#include <tuple>
7
#include <unordered_map>
8
#include <unordered_set>
9
#include <vector>
10

11
#include "sdfg/analysis/analysis.h"
12
#include "sdfg/builder/structured_sdfg_builder.h"
13
#include "sdfg/data_flow/access_node.h"
14
#include "sdfg/data_flow/data_flow_node.h"
15
#include "sdfg/data_flow/memlet.h"
16
#include "sdfg/data_flow/tasklet.h"
17
#include "sdfg/einsum/einsum.h"
18
#include "sdfg/element.h"
19
#include "sdfg/structured_control_flow/block.h"
20
#include "sdfg/transformations/transformation.h"
21
#include "sdfg/types/type.h"
22

23
namespace sdfg {
24
namespace transformations {
25

26
EinsumExtend::EinsumExtend(einsum::EinsumNode& einsum_node) : einsum_node_(einsum_node) {}
15✔
27

NEW
28
// Identifier of this transformation; to_json writes it as "transformation_type".
std::string EinsumExtend::name() const {
    return std::string{"EinsumExtend"};
}
×
29

30
// Returns true if at least one non-reduction input of the EinsumNode can be fused with the
// fp_mul tasklet that produces it.
//
// An input qualifies when its source access node
//  - is not a ConstantNode,
//  - has exactly one in edge (the tasklet is its only writer in this data flow graph) and exactly
//    one out edge (the EinsumNode is its only reader), and
//  - is written by an fp_mul tasklet that has no other out edge.
// Only then can apply() remove the tasklet and the access node without dropping a value that is
// still read by another node of this data flow graph.
//
// EinsumNodes with dimensions are not handled.
// NOTE(review): reads of the access node's container outside this data flow graph are not
// checked here.
bool EinsumExtend::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // Skip EinsumNodes with dimensions
    if (this->einsum_node_.dims().size() > 0) {
        return false;
    }

    auto& dfg = this->einsum_node_.get_parent();

    // Copy the reduction connector name so the comparison below never depends on the lifetime
    // of the container returned by inputs()
    const std::string reduction_conn = this->einsum_node_.inputs().back();

    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
        // Skip the reduction container (last input connector)
        if (iedge.dst_conn() == reduction_conn) {
            continue;
        }

        // Constants cannot be produced by a tasklet
        auto& access_node = static_cast<data_flow::AccessNode&>(iedge.src());
        if (dynamic_cast<data_flow::ConstantNode*>(&access_node)) {
            continue;
        }

        // The access node must be a private temporary between the tasklet and this einsum node
        if (dfg.in_degree(access_node) != 1 || dfg.out_degree(access_node) != 1) {
            continue;
        }

        // Its single writer must be a multiplication tasklet without other outputs
        for (auto& access_node_iedge : dfg.in_edges(access_node)) {
            auto* tasklet = dynamic_cast<data_flow::Tasklet*>(&access_node_iedge.src());
            if (tasklet && tasklet->code() == data_flow::TaskletCode::fp_mul && dfg.out_degree(*tasklet) == 1) {
                return true;
            }
        }
    }

    return false;
}
15✔
61

62
void EinsumExtend::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
9✔
63
    auto& dfg = this->einsum_node_.get_parent();
9✔
64
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
9✔
65
    assert(block);
9✔
66

67
    // Construct inputs, in indices, and a map for new in edges
68
    std::vector<std::string> inputs;
9✔
69
    std::vector<data_flow::Subset> in_indices;
9✔
70
    std::unordered_map<std::string, std::tuple<data_flow::AccessNode&, const types::IType&, DebugInfo>> new_iedges_map;
9✔
71
    std::unordered_set<data_flow::Memlet*> memlets_for_removal;
9✔
72
    std::unordered_set<data_flow::DataFlowNode*> nodes_for_removal;
9✔
73
    DebugInfo new_deb_info(this->einsum_node_.debug_info());
9✔
74
    for (size_t i = 0; i < this->einsum_node_.inputs().size() - 1; i++) {
26✔
75
        auto& conn = this->einsum_node_.input(i);
17✔
76

77
        // Find corresponding in edge
78
        data_flow::Memlet* iedge = nullptr;
17✔
79
        for (auto& in_edge : dfg.in_edges(this->einsum_node_)) {
26✔
80
            if (in_edge.dst_conn() == conn) {
26✔
81
                iedge = &in_edge;
17✔
82
                break;
17✔
83
            }
17✔
84
        }
26✔
85
        assert(iedge);
17✔
86

87
        // Check if at the access node there is a multiplication tasklet
88
        auto& access_node = static_cast<data_flow::AccessNode&>(iedge->src());
17✔
89
        data_flow::Tasklet* tasklet = nullptr;
17✔
90
        data_flow::Memlet* tasklet_oedge = nullptr;
17✔
91
        if (!dynamic_cast<data_flow::ConstantNode*>(&access_node) && dfg.in_degree(access_node) > 0) {
17✔
92
            for (auto& access_node_iedge : dfg.in_edges(access_node)) {
10✔
93
                auto* tmp_taskelt = dynamic_cast<data_flow::Tasklet*>(&access_node_iedge.src());
10✔
94
                if (tmp_taskelt && tmp_taskelt->code() == data_flow::TaskletCode::fp_mul) {
10✔
95
                    tasklet = tmp_taskelt;
10✔
96
                    tasklet_oedge = &access_node_iedge;
10✔
97
                    break;
10✔
98
                }
10✔
99
            }
10✔
100
        }
10✔
101

102
        // Fill the data ...
103
        if (tasklet) {
17✔
104
            // ... with new access nodes and connectors
105
            for (auto& tasklet_iedge : dfg.in_edges(*tasklet)) {
20✔
106
                auto& tasklet_access_node = static_cast<data_flow::AccessNode&>(tasklet_iedge.src());
20✔
107
                std::string new_conn = iedge->dst_conn() + tasklet_iedge.dst_conn();
20✔
108
                inputs.push_back(new_conn);
20✔
109
                in_indices.push_back(tasklet_iedge.subset());
20✔
110
                new_iedges_map.insert(
20✔
111
                    {new_conn,
20✔
112
                     {tasklet_access_node,
20✔
113
                      tasklet_iedge.base_type(),
20✔
114
                      DebugInfo::merge(iedge->debug_info(), tasklet_iedge.debug_info())}}
20✔
115
                );
20✔
116
                memlets_for_removal.insert(&tasklet_iedge);
20✔
117
            }
20✔
118

119
            // Mark tasklet and its memlets for removal
120
            memlets_for_removal.insert(tasklet_oedge);
10✔
121
            new_deb_info = DebugInfo::merge(new_deb_info, tasklet->debug_info());
10✔
122
            nodes_for_removal.insert(tasklet);
10✔
123

124
            // Mark acess node for removal if not used elsewhere
125
            if (dfg.in_degree(access_node) == 1 && dfg.out_degree(access_node) == 1) {
10✔
126
                nodes_for_removal.insert(&access_node);
10✔
127
            }
10✔
128
        } else {
10✔
129
            // ... with the old stuff
130
            inputs.push_back(conn);
7✔
131
            in_indices.push_back(this->einsum_node_.in_indices(i));
7✔
132
            new_iedges_map.insert({conn, {access_node, iedge->base_type(), iedge->debug_info()}});
7✔
133
        }
7✔
134

135
        // Mark in edge for removal
136
        memlets_for_removal.insert(iedge);
17✔
137
    }
17✔
138

139
    // Special handling for the reduction input
140
    {
9✔
141
        auto& conn = this->einsum_node_.inputs().back();
9✔
142

143
        // Find corresponding in edge
144
        data_flow::Memlet* iedge = nullptr;
9✔
145
        for (auto& in_edge : dfg.in_edges(this->einsum_node_)) {
26✔
146
            if (in_edge.dst_conn() == conn) {
26✔
147
                iedge = &in_edge;
9✔
148
                break;
9✔
149
            }
9✔
150
        }
26✔
151
        assert(iedge);
9✔
152

153
        // Mapping and marking for removal
154
        auto& access_node = static_cast<data_flow::AccessNode&>(iedge->src());
9✔
155
        new_iedges_map.insert({conn, {access_node, iedge->base_type(), iedge->debug_info()}});
9✔
156
        memlets_for_removal.insert(iedge);
9✔
157
    }
9✔
158

159
    // Create new einsum node
NEW
160
    auto& new_libnode = builder.add_library_node<
×
161
        einsum::EinsumNode,
9✔
162
        const std::vector<std::string>&,
9✔
163
        const std::vector<einsum::EinsumDimension>&,
9✔
164
        const data_flow::Subset&,
9✔
165
        const std::vector<
9✔
166
            data_flow::Subset>&>(*block, new_deb_info, inputs, {}, this->einsum_node_.out_indices(), in_indices);
9✔
167
    auto& new_einsum_node = static_cast<einsum::EinsumNode&>(new_libnode);
9✔
168

169
    // Construct in edges
170
    for (auto& conn : new_einsum_node.inputs()) {
36✔
171
        auto [access_node, type, deb_info] = new_iedges_map.at(conn);
36✔
172
        builder.add_memlet(*block, access_node, "void", new_libnode, conn, {}, type, deb_info);
36✔
173
    }
36✔
174

175
    // Remove marked memlets & nodes
176
    for (auto* memlet : memlets_for_removal) {
56✔
177
        builder.remove_memlet(*block, *memlet);
56✔
178
    }
56✔
179
    for (auto* node : nodes_for_removal) {
20✔
180
        builder.remove_node(*block, *node);
20✔
181
    }
20✔
182

183
    // Redirect out edges
184
    while (dfg.out_edges(this->einsum_node_).begin() != dfg.out_edges(this->einsum_node_).end()) {
18✔
185
        auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
9✔
186
        builder.add_memlet(
9✔
187
            *block,
9✔
188
            new_libnode,
9✔
189
            oedge.src_conn(),
9✔
190
            oedge.dst(),
9✔
191
            oedge.dst_conn(),
9✔
192
            oedge.subset(),
9✔
193
            oedge.base_type(),
9✔
194
            oedge.debug_info()
9✔
195
        );
9✔
196
        builder.remove_memlet(*block, oedge);
9✔
197
    }
9✔
198

199
    // Remove old einsum node
200
    builder.remove_node(*block, this->einsum_node_);
9✔
201

202
    analysis_manager.invalidate_all();
9✔
203
}
9✔
204

NEW
205
// Serializes this transformation: its type name under "transformation_type" and the element id
// of the targeted EinsumNode under "einsum_node_element_id".
void EinsumExtend::to_json(nlohmann::json& j) const {
    const auto node_id = einsum_node_.element_id();
    j["transformation_type"] = name();
    j["einsum_node_element_id"] = node_id;
}
×
209

NEW
210
// Reconstructs an EinsumExtend from a JSON description such as the one written by to_json.
//
// Expects an unsigned integer field "einsum_node_element_id"; the element is looked up with
// builder.find_element_by_id.
//
// Throws InvalidTransformationDescriptionException if the field is missing or not an unsigned
// integer, if no element with that id exists, or if the element is not an EinsumNode.
EinsumExtend EinsumExtend::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    // The description is external input: validate with exceptions instead of asserts. Asserts are
    // compiled out under NDEBUG, after which operator[] on a missing key of a const json is
    // undefined behavior. find() also yields end() when j is not an object.
    const auto field = j.find("einsum_node_element_id");
    if (field == j.end() || !field->is_number_unsigned()) {
        throw InvalidTransformationDescriptionException(
            std::string("Field einsum_node_element_id is missing or not an unsigned integer")
        );
    }
    const size_t einsum_node_id = field->get<size_t>();

    auto* einsum_node_element = builder.find_element_by_id(einsum_node_id);
    if (!einsum_node_element) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " not found"
        );
    }
    auto* einsum_node = dynamic_cast<einsum::EinsumNode*>(einsum_node_element);
    if (!einsum_node) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " is not an EinsumNode"
        );
    }

    return EinsumExtend(*einsum_node);
}
×
229

230
} // namespace transformations
231
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc