• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23492793883

24 Mar 2026 01:47PM UTC coverage: 64.456% (+0.2%) from 64.295%
23492793883

Pull #605

github

web-flow
Merge fdd141272 into e56781552
Pull Request #605: Move einsum support

1303 of 1918 new or added lines in 14 files covered. (67.94%)

45 existing lines in 3 files now uncovered.

27952 of 43366 relevant lines covered (64.46%)

392.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/opt/src/transformations/einsum2dot.cpp
1
#include "sdfg/transformations/einsum2dot.h"
2

3
#include <cstddef>
4
#include <nlohmann/json_fwd.hpp>
5
#include <string>
6

7
#include "sdfg/analysis/analysis.h"
8
#include "sdfg/builder/structured_sdfg_builder.h"
9
#include "sdfg/data_flow/access_node.h"
10
#include "sdfg/data_flow/library_node.h"
11
#include "sdfg/data_flow/library_nodes/math/blas/dot_node.h"
12
#include "sdfg/data_flow/library_nodes/math/math.h"
13
#include "sdfg/einsum/einsum.h"
14
#include "sdfg/optimization_report/pass_report_consumer.h"
15
#include "sdfg/structured_control_flow/block.h"
16
#include "sdfg/symbolic/symbolic.h"
17
#include "sdfg/targets/cuda/cuda.h"
18
#include "sdfg/targets/tenstorrent/library_node_mapping.h"
19
#include "sdfg/transformations/transformation.h"
20
#include "sdfg/types/type.h"
21
#include "sdfg/types/utils.h"
22

23
namespace sdfg {
24
namespace transformations {
25

26
Einsum2Dot::Einsum2Dot(einsum::EinsumNode& einsum_node, const std::string& target_tune)
NEW
27
    : einsum_node_(einsum_node), target_tune_(target_tune) {}
×
28

NEW
29
/// Returns the identifier under which this transformation is serialized.
std::string Einsum2Dot::name() const {
    const std::string transformation_name = "Einsum2Dot";
    return transformation_name;
}
×
30

NEW
31
/// Maps the configured target tune to a DOT library-node implementation.
///
/// @param data_type  primitive element type; only consulted for the
///                   "tenstorrent" tune, where availability depends on it.
/// @return the implementation type, or std::nullopt when the tune is unknown
///         or (for "tenstorrent") no mapping exists for data_type.
std::optional<sdfg::data_flow::ImplementationType> Einsum2Dot::get_impl_type(types::PrimitiveType data_type) {
    if (target_tune_ == "sequential") {
        return sdfg::data_flow::ImplementationType_NONE;
    }
    if (target_tune_ == "openmp") {
        return sdfg::math::blas::ImplementationType_BLAS;
    }
    if (target_tune_ == "cuda") {
        return sdfg::cuda::blas::ImplementationType_CUBLASWithTransfers;
    }
    if (target_tune_ == "tenstorrent") {
        // The Tenstorrent mapping may itself report "no implementation".
        return tenstorrent::try_map_library_node_implementation(math::blas::LibraryNodeType_DOT, data_type);
    }

    // Unrecognised tune: no implementation available.
    return std::nullopt;
}
×
49

NEW
50
/// Checks whether the einsum node has the shape that apply() can lower to a
/// BLAS DOT library node.
///
/// Requirements checked here:
///  - exactly one loop dimension, whose init is zero;
///  - no output indices;
///  - exactly three inputs, the third being the same connector as output(0)
///    (the accumulator), fed by an access node (apply() removes that node);
///  - inputs 0 and 1 each indexed by exactly the loop variable;
///  - at least one output edge, whose primitive type is Float or Double;
///  - every input memlet having that same primitive type;
///  - an implementation existing for the configured target tune.
///
/// @param builder           not used by this check.
/// @param analysis_manager  not used by this check.
/// @return true if apply() may be called.
bool Einsum2Dot::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // Check dims: a single loop dimension initialised at zero
    if (this->einsum_node_.dims().size() != 1 || !symbolic::eq(this->einsum_node_.init(0), symbolic::zero())) {
        return false;
    }
    symbolic::Symbol indvar = this->einsum_node_.indvar(0);

    // Check out indices: the result must not be indexed
    if (this->einsum_node_.out_indices().size() != 0) {
        return false;
    }

    // Check inputs: x, y and the accumulator, which aliases the output connector
    if (this->einsum_node_.inputs().size() != 3 || this->einsum_node_.input(2) != this->einsum_node_.output(0)) {
        return false;
    }

    // Check in indices: both operands are indexed by the loop variable only
    if (this->einsum_node_.in_indices(0).size() != 1 || !symbolic::eq(this->einsum_node_.in_index(0, 0), indvar)) {
        return false;
    }
    if (this->einsum_node_.in_indices(1).size() != 1 || !symbolic::eq(this->einsum_node_.in_index(1, 0), indvar)) {
        return false;
    }

    // Get the data flow graph
    auto& dfg = this->einsum_node_.get_parent();

    // The data type is read from the first output edge; dereferencing begin()
    // of an empty edge range would be undefined behavior.
    if (dfg.out_edges(this->einsum_node_).begin() == dfg.out_edges(this->einsum_node_).end()) {
        return false;
    }

    // Determine and check the base type of output
    auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
    auto data_type = oedge.base_type().primitive_type();
    if (data_type != types::PrimitiveType::Float && data_type != types::PrimitiveType::Double) {
        return false;
    }

    // Check that all inputs share the output's primitive type, and that the
    // accumulator input is fed by an access node: apply() dereferences and
    // removes that node, so it must exist.
    bool has_accumulator_access_node = false;
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
        if (iedge.base_type().primitive_type() != data_type) {
            return false;
        }
        if (iedge.dst_conn() == this->einsum_node_.input(2) &&
            dynamic_cast<const data_flow::AccessNode*>(&iedge.src()) != nullptr) {
            has_accumulator_access_node = true;
        }
    }
    if (!has_accumulator_access_node) {
        return false;
    }

    if (!get_impl_type(data_type)) { // no implementation for the given tune exists
        return false;
    }

    return true;
}
×
98

NEW
99
void Einsum2Dot::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
100
    // Get the data flow graph
NEW
101
    auto& dfg = this->einsum_node_.get_parent();
×
102

103
    // Get the block in which the einsum node lives
NEW
104
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
×
NEW
105
    assert(block);
×
106

107
    // Get the number of iterations (n)
NEW
108
    symbolic::Expression n = this->einsum_node_.bound(0);
×
109

110
    // Determine the BLAS precision
NEW
111
    math::blas::BLAS_Precision precision;
×
NEW
112
    auto& datatype_oedge = *dfg.out_edges(this->einsum_node_).begin();
×
NEW
113
    types::PrimitiveType data_type = datatype_oedge.base_type().primitive_type();
×
NEW
114
    if (data_type == types::PrimitiveType::Float) {
×
NEW
115
        precision = math::blas::BLAS_Precision::s;
×
NEW
116
    } else {
×
NEW
117
        precision = math::blas::BLAS_Precision::d;
×
NEW
118
    }
×
119

120
    // Add the dot node
NEW
121
    auto& libnode = builder.add_library_node<
×
NEW
122
        math::blas::DotNode,
×
NEW
123
        const data_flow::ImplementationType&,
×
NEW
124
        const math::blas::BLAS_Precision&,
×
NEW
125
        symbolic::Expression>(
×
NEW
126
        *block, this->einsum_node_.debug_info(), this->get_impl_type(data_type).value(), precision, n
×
NEW
127
    );
×
128

129
    // Copy the memlets
NEW
130
    data_flow::AccessNode* leftover_access_node = nullptr;
×
NEW
131
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
×
NEW
132
        if (iedge.dst_conn() == this->einsum_node_.input(0)) {
×
NEW
133
            builder.add_memlet(
×
NEW
134
                *block,
×
NEW
135
                iedge.src(),
×
NEW
136
                iedge.src_conn(),
×
NEW
137
                libnode,
×
NEW
138
                "__x",
×
NEW
139
                iedge.subset(),
×
NEW
140
                iedge.base_type(),
×
NEW
141
                iedge.debug_info()
×
NEW
142
            );
×
NEW
143
        } else if (iedge.dst_conn() == this->einsum_node_.input(1)) {
×
NEW
144
            builder.add_memlet(
×
NEW
145
                *block,
×
NEW
146
                iedge.src(),
×
NEW
147
                iedge.src_conn(),
×
NEW
148
                libnode,
×
NEW
149
                "__y",
×
NEW
150
                iedge.subset(),
×
NEW
151
                iedge.base_type(),
×
NEW
152
                iedge.debug_info()
×
NEW
153
            );
×
NEW
154
        } else if (iedge.dst_conn() == this->einsum_node_.input(2)) {
×
NEW
155
            leftover_access_node = dynamic_cast<data_flow::AccessNode*>(&iedge.src());
×
NEW
156
        }
×
NEW
157
    }
×
NEW
158
    for (auto& oedge : dfg.out_edges(this->einsum_node_)) {
×
NEW
159
        if (oedge.src_conn() == this->einsum_node_.output(0)) {
×
NEW
160
            builder.add_memlet(
×
NEW
161
                *block,
×
NEW
162
                libnode,
×
NEW
163
                "__out",
×
NEW
164
                oedge.dst(),
×
NEW
165
                oedge.dst_conn(),
×
NEW
166
                oedge.subset(),
×
NEW
167
                oedge.base_type(),
×
NEW
168
                oedge.debug_info()
×
NEW
169
            );
×
NEW
170
        }
×
NEW
171
    }
×
172

173
    // Remove the old memlets
NEW
174
    while (dfg.in_edges(this->einsum_node_).begin() != dfg.in_edges(this->einsum_node_).end()) {
×
NEW
175
        auto& iedge = *dfg.in_edges(this->einsum_node_).begin();
×
NEW
176
        builder.remove_memlet(*block, iedge);
×
NEW
177
    }
×
NEW
178
    while (dfg.out_edges(this->einsum_node_).begin() != dfg.out_edges(this->einsum_node_).end()) {
×
NEW
179
        auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
×
NEW
180
        builder.remove_memlet(*block, oedge);
×
NEW
181
    }
×
182

183
    // Remove leftover access node
NEW
184
    builder.remove_node(*block, *leftover_access_node);
×
185

186
    // Remove the einsum node
NEW
187
    builder.remove_node(*block, this->einsum_node_);
×
188

NEW
189
    analysis_manager.invalidate_all();
×
NEW
190
}
×
191

NEW
192
/// Serializes this transformation into `j`.
///
/// Writes the keys "transformation_type", "einsum_node_element_id" and
/// "target_tune".
void Einsum2Dot::to_json(nlohmann::json& j) const {
    const auto node_id = einsum_node_.element_id();

    j["transformation_type"] = name();
    j["einsum_node_element_id"] = node_id;
    j["target_tune"] = target_tune_;
}
×
197

NEW
198
/// Reconstructs an Einsum2Dot transformation from its JSON description.
///
/// Expects an unsigned "einsum_node_element_id" that identifies an
/// EinsumNode in `builder`'s SDFG, and an optional "target_tune" string
/// (defaults to "none"). This is the counterpart of to_json(), which writes
/// exactly these keys; no "impl_type" key is required, since the
/// implementation is derived from the target tune.
///
/// @param builder  builder whose SDFG is searched for the element id.
/// @param j        the serialized transformation.
/// @return the reconstructed transformation.
/// @throws InvalidTransformationDescriptionException if the id is missing or
///         not an unsigned number, if no element has that id, or if the
///         element is not an EinsumNode.
Einsum2Dot Einsum2Dot::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    // Validate with real checks instead of asserts: under NDEBUG the asserts
    // disappear, and const json::operator[] on a missing key is undefined.
    if (!j.contains("einsum_node_element_id") || !j.at("einsum_node_element_id").is_number_unsigned()) {
        throw InvalidTransformationDescriptionException(
            std::string("Missing or invalid einsum_node_element_id")
        );
    }

    size_t einsum_node_id = j.at("einsum_node_element_id").get<size_t>();
    auto* einsum_node_element = builder.find_element_by_id(einsum_node_id);
    if (!einsum_node_element) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " not found"
        );
    }
    auto* einsum_node = dynamic_cast<einsum::EinsumNode*>(einsum_node_element);
    if (!einsum_node) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " is not an EinsumNode"
        );
    }

    std::string target_tune;
    if (j.contains("target_tune")) {
        target_tune = j.at("target_tune").get<std::string>();
    } else {
        target_tune = "none";
    }

    return Einsum2Dot(*einsum_node, target_tune);
}
×
226

227
} // namespace transformations
228
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc