• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23884192681

01 Apr 2026 08:03PM UTC coverage: 64.453% (+0.06%) from 64.398%
23884192681

push

github

web-flow
Merge pull request #626 from daisytuner/einsum-conversion

adds einsum conversion to expand bindings

2 of 11 new or added lines in 4 files covered. (18.18%)

2 existing lines in 2 files now uncovered.

28779 of 44651 relevant lines covered (64.45%)

386.15 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/opt/src/transformations/einsum2dot.cpp
1
#include "sdfg/transformations/einsum2dot.h"
2

3
#include <cstddef>
4
#include <nlohmann/json_fwd.hpp>
5
#include <string>
6

7
#include "sdfg/analysis/analysis.h"
8
#include "sdfg/builder/structured_sdfg_builder.h"
9
#include "sdfg/data_flow/access_node.h"
10
#include "sdfg/data_flow/library_node.h"
11
#include "sdfg/data_flow/library_nodes/math/blas/dot_node.h"
12
#include "sdfg/data_flow/library_nodes/math/math.h"
13
#include "sdfg/einsum/einsum.h"
14
#include "sdfg/optimization_report/pass_report_consumer.h"
15
#include "sdfg/structured_control_flow/block.h"
16
#include "sdfg/symbolic/symbolic.h"
17
#include "sdfg/targets/cuda/cuda.h"
18
#include "sdfg/transformations/transformation.h"
19
#include "sdfg/types/type.h"
20
#include "sdfg/types/utils.h"
21

22
namespace sdfg {
23
namespace transformations {
24

NEW
25
Einsum2Dot::Einsum2Dot(einsum::EinsumNode& einsum_node) : einsum_node_(einsum_node) {}
×
26

27
/// Returns the transformation's identifier (also written as "transformation_type" by to_json()).
std::string Einsum2Dot::name() const {
    return std::string("Einsum2Dot");
}
×
28

29
/// Checks whether the einsum node can be replaced by a BLAS dot node.
///
/// Requirements:
///  - exactly one loop dimension whose init is zero;
///  - no output indices;
///  - exactly three inputs, where input(2) is the same connector as output(0);
///  - input(0) and input(1) are each indexed by exactly the loop variable;
///  - at least one out edge, whose base type is float or double;
///  - all in edges share that primitive type;
///  - the accumulator connector input(2) is fed by an AccessNode that has no in edges
///    and no other out edges, because apply() removes that node.
///
/// @param builder          unused
/// @param analysis_manager unused
/// @return true if apply() may be called safely
bool Einsum2Dot::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // Check dims: a single loop dimension starting at zero
    if (this->einsum_node_.dims().size() != 1 || !symbolic::eq(this->einsum_node_.init(0), symbolic::zero())) {
        return false;
    }
    symbolic::Symbol indvar = this->einsum_node_.indvar(0);

    // Check out indices: the output must not be indexed
    if (this->einsum_node_.out_indices().size() != 0) {
        return false;
    }

    // Check inputs: two operands plus input(2), which must be the output connector itself
    if (this->einsum_node_.inputs().size() != 3 || this->einsum_node_.input(2) != this->einsum_node_.output(0)) {
        return false;
    }

    // Check in indices: both operands are indexed by exactly the loop variable
    if (this->einsum_node_.in_indices(0).size() != 1 || !symbolic::eq(this->einsum_node_.in_index(0, 0), indvar)) {
        return false;
    }
    if (this->einsum_node_.in_indices(1).size() != 1 || !symbolic::eq(this->einsum_node_.in_index(1, 0), indvar)) {
        return false;
    }

    // Get the data flow graph
    auto& dfg = this->einsum_node_.get_parent();

    // The output type is read through the first out edge; dereferencing begin() of an
    // empty edge range would be undefined behaviour, so reject nodes without out edges.
    if (dfg.out_edges(this->einsum_node_).begin() == dfg.out_edges(this->einsum_node_).end()) {
        return false;
    }

    // Determine and check the base type of output
    auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
    auto data_type = oedge.base_type().primitive_type();
    if (data_type != types::PrimitiveType::Float && data_type != types::PrimitiveType::Double) {
        return false;
    }

    // Check that all inputs have the same primitive type, and locate the access node that
    // feeds the accumulator connector input(2).
    data_flow::AccessNode* accumulator = nullptr;
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
        if (iedge.base_type().primitive_type() != data_type) {
            return false;
        }
        if (iedge.dst_conn() == this->einsum_node_.input(2)) {
            accumulator = dynamic_cast<data_flow::AccessNode*>(&iedge.src());
        }
    }

    // apply() dereferences and removes this node, so it must exist...
    if (accumulator == nullptr) {
        return false;
    }

    // ...and must not be used by anything besides this einsum node. Otherwise removing it
    // would break a producer or another consumer.
    if (dfg.in_edges(*accumulator).begin() != dfg.in_edges(*accumulator).end()) {
        return false;
    }
    size_t accumulator_uses = 0;
    for (auto& use : dfg.out_edges(*accumulator)) {
        (void) use;
        ++accumulator_uses;
    }
    if (accumulator_uses != 1) {
        return false;
    }

    return true;
}
×
73

74
void Einsum2Dot::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
75
    // Get the data flow graph
76
    auto& dfg = this->einsum_node_.get_parent();
×
77

78
    // Get the block in which the einsum node lives
79
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
×
80
    assert(block);
×
81

82
    // Get the number of iterations (n)
83
    symbolic::Expression n = this->einsum_node_.bound(0);
×
84

85
    // Determine the BLAS precision
86
    math::blas::BLAS_Precision precision;
×
87
    auto& datatype_oedge = *dfg.out_edges(this->einsum_node_).begin();
×
88
    types::PrimitiveType data_type = datatype_oedge.base_type().primitive_type();
×
89
    if (data_type == types::PrimitiveType::Float) {
×
90
        precision = math::blas::BLAS_Precision::s;
×
91
    } else {
×
92
        precision = math::blas::BLAS_Precision::d;
×
93
    }
×
94

95
    // Add the dot node
96
    auto& libnode = builder.add_library_node<
×
97
        math::blas::DotNode,
×
98
        const data_flow::ImplementationType&,
×
99
        const math::blas::BLAS_Precision&,
×
100
        symbolic::Expression>(
×
NEW
101
        *block, this->einsum_node_.debug_info(), sdfg::math::blas::ImplementationType_BLAS, precision, n
×
102
    );
×
103

104
    // Copy the memlets
105
    data_flow::AccessNode* leftover_access_node = nullptr;
×
106
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
×
107
        if (iedge.dst_conn() == this->einsum_node_.input(0)) {
×
108
            builder.add_memlet(
×
109
                *block,
×
110
                iedge.src(),
×
111
                iedge.src_conn(),
×
112
                libnode,
×
113
                "__x",
×
114
                iedge.subset(),
×
115
                iedge.base_type(),
×
116
                iedge.debug_info()
×
117
            );
×
118
        } else if (iedge.dst_conn() == this->einsum_node_.input(1)) {
×
119
            builder.add_memlet(
×
120
                *block,
×
121
                iedge.src(),
×
122
                iedge.src_conn(),
×
123
                libnode,
×
124
                "__y",
×
125
                iedge.subset(),
×
126
                iedge.base_type(),
×
127
                iedge.debug_info()
×
128
            );
×
129
        } else if (iedge.dst_conn() == this->einsum_node_.input(2)) {
×
130
            leftover_access_node = dynamic_cast<data_flow::AccessNode*>(&iedge.src());
×
131
        }
×
132
    }
×
133
    for (auto& oedge : dfg.out_edges(this->einsum_node_)) {
×
134
        if (oedge.src_conn() == this->einsum_node_.output(0)) {
×
135
            builder.add_memlet(
×
136
                *block,
×
137
                libnode,
×
138
                "__out",
×
139
                oedge.dst(),
×
140
                oedge.dst_conn(),
×
141
                oedge.subset(),
×
142
                oedge.base_type(),
×
143
                oedge.debug_info()
×
144
            );
×
145
        }
×
146
    }
×
147

148
    // Remove the old memlets
149
    while (dfg.in_edges(this->einsum_node_).begin() != dfg.in_edges(this->einsum_node_).end()) {
×
150
        auto& iedge = *dfg.in_edges(this->einsum_node_).begin();
×
151
        builder.remove_memlet(*block, iedge);
×
152
    }
×
153
    while (dfg.out_edges(this->einsum_node_).begin() != dfg.out_edges(this->einsum_node_).end()) {
×
154
        auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
×
155
        builder.remove_memlet(*block, oedge);
×
156
    }
×
157

158
    // Remove leftover access node
159
    builder.remove_node(*block, *leftover_access_node);
×
160

161
    // Remove the einsum node
162
    builder.remove_node(*block, this->einsum_node_);
×
163

164
    analysis_manager.invalidate_all();
×
165
}
×
166

167
/// Serializes this transformation into `j`.
/// Writes the transformation name under "transformation_type" and the element id of the
/// target einsum node under "einsum_node_element_id".
void Einsum2Dot::to_json(nlohmann::json& j) const {
    const std::string type_name = this->name();
    j["transformation_type"] = type_name;

    const auto& target = this->einsum_node_;
    j["einsum_node_element_id"] = target.element_id();
}
×
171

172
/// Reconstructs an Einsum2Dot transformation from a JSON description produced by to_json().
///
/// @param builder builder used to resolve the element id to a node
/// @param j       JSON object containing an unsigned "einsum_node_element_id"
/// @return the transformation targeting the referenced einsum node
/// @throws InvalidTransformationDescriptionException if the id field is missing or not an
///         unsigned integer, if no element has that id, or if the element is not an EinsumNode
Einsum2Dot Einsum2Dot::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    // Validate with exceptions rather than assert(). The description is external input, and
    // assert() is compiled out in release builds, where reading a missing key through the const
    // operator[] would be undefined behaviour.
    if (!j.contains("einsum_node_element_id")) {
        throw InvalidTransformationDescriptionException("Missing field 'einsum_node_element_id'");
    }
    const auto& id_json = j.at("einsum_node_element_id");
    if (!id_json.is_number_unsigned()) {
        throw InvalidTransformationDescriptionException("Field 'einsum_node_element_id' must be an unsigned integer");
    }

    size_t einsum_node_id = id_json.get<size_t>();
    auto* einsum_node_element = builder.find_element_by_id(einsum_node_id);
    if (!einsum_node_element) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " not found"
        );
    }
    auto* einsum_node = dynamic_cast<einsum::EinsumNode*>(einsum_node_element);
    if (!einsum_node) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " is not an EinsumNode"
        );
    }

    return Einsum2Dot(*einsum_node);
}
×
192

193
} // namespace transformations
194
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc