• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23484120816

24 Mar 2026 10:12AM UTC coverage: 64.587% (+0.3%) from 64.295%
23484120816

Pull #605

github

web-flow
Merge 03f1c151d into 0428b1aa7
Pull Request #605: Move einsum support

1307 of 1836 new or added lines in 10 files covered. (71.19%)

45 existing lines in 3 files now uncovered.

27956 of 43284 relevant lines covered (64.59%)

393.57 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/opt/src/transformations/einsum2dot.cpp
1
#include "sdfg/transformations/einsum2dot.h"
2

3
#include <cstddef>
4
#include <nlohmann/json_fwd.hpp>
5
#include <string>
6

7
#include "sdfg/analysis/analysis.h"
8
#include "sdfg/builder/structured_sdfg_builder.h"
9
#include "sdfg/data_flow/access_node.h"
10
#include "sdfg/data_flow/library_node.h"
11
#include "sdfg/data_flow/library_nodes/math/blas/dot_node.h"
12
#include "sdfg/data_flow/library_nodes/math/math.h"
13
#include "sdfg/einsum/einsum.h"
14
#include "sdfg/optimization_report/pass_report_consumer.h"
15
#include "sdfg/structured_control_flow/block.h"
16
#include "sdfg/symbolic/symbolic.h"
17
#include "sdfg/targets/cuda/cuda.h"
18
#include "sdfg/transformations/transformation.h"
19
#include "sdfg/types/type.h"
20
#include "sdfg/types/utils.h"
21

22
namespace sdfg {
23
namespace transformations {
24

25
Einsum2Dot::Einsum2Dot(einsum::EinsumNode& einsum_node, const std::string& target_tune)
NEW
26
    : einsum_node_(einsum_node), target_tune_(target_tune) {}
×
27

NEW
28
/// Stable identifier of this transformation; to_json also stores it under "transformation_type".
std::string Einsum2Dot::name() const {
    static constexpr const char* kName = "Einsum2Dot";
    return std::string{kName};
}
×
29

NEW
30
/// Map the configured target tune to a dot-node implementation type.
///
/// @param data_type  Primitive type of the operands; only consulted for "tenstorrent".
/// @return "sequential" -> ImplementationType_NONE, "openmp" -> BLAS,
///         "cuda" -> cuBLAS with transfers, "tenstorrent" -> "TENSTORRENT_WithTransfers"
///         (Float only); std::nullopt for any other tune/type combination.
std::optional<sdfg::data_flow::ImplementationType> Einsum2Dot::get_impl_type(types::PrimitiveType data_type) {
    if (target_tune_ == "sequential") {
        return sdfg::data_flow::ImplementationType_NONE;
    }
    if (target_tune_ == "openmp") {
        return sdfg::math::blas::ImplementationType_BLAS;
    }
    if (target_tune_ == "cuda") {
        return sdfg::cuda::blas::ImplementationType_CUBLASWithTransfers;
    }
    if (target_tune_ == "tenstorrent" && data_type == types::PrimitiveType::Float) {
        // Only Float is mapped for this target; other types fall through to nullopt.
        return data_flow::ImplementationType{"TENSTORRENT_WithTransfers"};
    }
    // No implementation is available for this tune (or tune/type combination).
    return std::nullopt;
}
×
50

NEW
51
/// Check whether the einsum node can be lowered to a BLAS dot node.
///
/// All of the following must hold:
///  - exactly one loop dimension, with init(0) equal to zero;
///  - no output indices (scalar result);
///  - exactly three inputs, the third being the output connector itself (accumulator);
///  - inputs 0 and 1 are each indexed by exactly the loop variable;
///  - at least one outgoing memlet, whose base type is Float or Double;
///  - every incoming memlet has that same primitive type;
///  - get_impl_type yields an implementation for the configured target tune.
/// The builder and analysis manager are not consulted.
bool Einsum2Dot::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    // Check dims: a single loop whose init is zero
    if (this->einsum_node_.dims().size() != 1 || !symbolic::eq(this->einsum_node_.init(0), symbolic::zero())) {
        return false;
    }
    symbolic::Symbol indvar = this->einsum_node_.indvar(0);

    // Check out indices: the result carries no indices
    if (this->einsum_node_.out_indices().size() != 0) {
        return false;
    }

    // Check inputs: x, y, and an accumulator that is the output connector itself
    if (this->einsum_node_.inputs().size() != 3 || this->einsum_node_.input(2) != this->einsum_node_.output(0)) {
        return false;
    }

    // Check in indices: x[indvar] and y[indvar]
    if (this->einsum_node_.in_indices(0).size() != 1 || !symbolic::eq(this->einsum_node_.in_index(0, 0), indvar)) {
        return false;
    }
    if (this->einsum_node_.in_indices(1).size() != 1 || !symbolic::eq(this->einsum_node_.in_index(1, 0), indvar)) {
        return false;
    }

    // Get the data flow graph
    auto& dfg = this->einsum_node_.get_parent();

    // Without an outgoing memlet there is no output type to inspect, and
    // dereferencing begin() of an empty range would be undefined behavior.
    if (dfg.out_edges(this->einsum_node_).begin() == dfg.out_edges(this->einsum_node_).end()) {
        return false;
    }

    // Determine and check the base type of output
    auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
    auto data_type = oedge.base_type().primitive_type();
    if (data_type != types::PrimitiveType::Float && data_type != types::PrimitiveType::Double) {
        return false;
    }

    // Check if all inputs have the same primitive type
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
        if (iedge.base_type().primitive_type() != data_type) {
            return false;
        }
    }

    if (!get_impl_type(data_type)) { // no implementation for the given tune exists
        return false;
    }

    return true;
}
×
99

NEW
100
void Einsum2Dot::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
101
    // Get the data flow graph
NEW
102
    auto& dfg = this->einsum_node_.get_parent();
×
103

104
    // Get the block in which the einsum node lives
NEW
105
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
×
NEW
106
    assert(block);
×
107

108
    // Get the number of iterations (n)
NEW
109
    symbolic::Expression n = this->einsum_node_.bound(0);
×
110

111
    // Determine the BLAS precision
NEW
112
    math::blas::BLAS_Precision precision;
×
NEW
113
    auto& datatype_oedge = *dfg.out_edges(this->einsum_node_).begin();
×
NEW
114
    types::PrimitiveType data_type = datatype_oedge.base_type().primitive_type();
×
NEW
115
    if (data_type == types::PrimitiveType::Float) {
×
NEW
116
        precision = math::blas::BLAS_Precision::s;
×
NEW
117
    } else {
×
NEW
118
        precision = math::blas::BLAS_Precision::d;
×
NEW
119
    }
×
120

121
    // Add the dot node
NEW
122
    auto& libnode = builder.add_library_node<
×
NEW
123
        math::blas::DotNode,
×
NEW
124
        const data_flow::ImplementationType&,
×
NEW
125
        const math::blas::BLAS_Precision&,
×
NEW
126
        symbolic::Expression>(
×
NEW
127
        *block, this->einsum_node_.debug_info(), this->get_impl_type(data_type).value(), precision, n
×
NEW
128
    );
×
129

130
    // Copy the memlets
NEW
131
    data_flow::AccessNode* leftover_access_node = nullptr;
×
NEW
132
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
×
NEW
133
        if (iedge.dst_conn() == this->einsum_node_.input(0)) {
×
NEW
134
            builder.add_memlet(
×
NEW
135
                *block,
×
NEW
136
                iedge.src(),
×
NEW
137
                iedge.src_conn(),
×
NEW
138
                libnode,
×
NEW
139
                "__x",
×
NEW
140
                iedge.subset(),
×
NEW
141
                iedge.base_type(),
×
NEW
142
                iedge.debug_info()
×
NEW
143
            );
×
NEW
144
        } else if (iedge.dst_conn() == this->einsum_node_.input(1)) {
×
NEW
145
            builder.add_memlet(
×
NEW
146
                *block,
×
NEW
147
                iedge.src(),
×
NEW
148
                iedge.src_conn(),
×
NEW
149
                libnode,
×
NEW
150
                "__y",
×
NEW
151
                iedge.subset(),
×
NEW
152
                iedge.base_type(),
×
NEW
153
                iedge.debug_info()
×
NEW
154
            );
×
NEW
155
        } else if (iedge.dst_conn() == this->einsum_node_.input(2)) {
×
NEW
156
            leftover_access_node = dynamic_cast<data_flow::AccessNode*>(&iedge.src());
×
NEW
157
        }
×
NEW
158
    }
×
NEW
159
    for (auto& oedge : dfg.out_edges(this->einsum_node_)) {
×
NEW
160
        if (oedge.src_conn() == this->einsum_node_.output(0)) {
×
NEW
161
            builder.add_memlet(
×
NEW
162
                *block,
×
NEW
163
                libnode,
×
NEW
164
                "__out",
×
NEW
165
                oedge.dst(),
×
NEW
166
                oedge.dst_conn(),
×
NEW
167
                oedge.subset(),
×
NEW
168
                oedge.base_type(),
×
NEW
169
                oedge.debug_info()
×
NEW
170
            );
×
NEW
171
        }
×
NEW
172
    }
×
173

174
    // Remove the old memlets
NEW
175
    while (dfg.in_edges(this->einsum_node_).begin() != dfg.in_edges(this->einsum_node_).end()) {
×
NEW
176
        auto& iedge = *dfg.in_edges(this->einsum_node_).begin();
×
NEW
177
        builder.remove_memlet(*block, iedge);
×
NEW
178
    }
×
NEW
179
    while (dfg.out_edges(this->einsum_node_).begin() != dfg.out_edges(this->einsum_node_).end()) {
×
NEW
180
        auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
×
NEW
181
        builder.remove_memlet(*block, oedge);
×
NEW
182
    }
×
183

184
    // Remove leftover access node
NEW
185
    builder.remove_node(*block, *leftover_access_node);
×
186

187
    // Remove the einsum node
NEW
188
    builder.remove_node(*block, this->einsum_node_);
×
189

NEW
190
    analysis_manager.invalidate_all();
×
NEW
191
}
×
192

NEW
193
/// Serialize this transformation.
///
/// Writes "transformation_type" (the value of name()), "einsum_node_element_id"
/// (the element id of the einsum node), and "target_tune".
void Einsum2Dot::to_json(nlohmann::json& j) const {
    const auto type_name = name();
    const auto node_id = einsum_node_.element_id();

    j["transformation_type"] = type_name;
    j["einsum_node_element_id"] = node_id;
    j["target_tune"] = target_tune_;
}
×
198

NEW
199
/// Rebuild an Einsum2Dot from a JSON description such as the one to_json writes.
///
/// Reads "einsum_node_element_id" (required, unsigned integer) and "target_tune"
/// (optional, defaults to "none"). The "transformation_type" key is not read here.
///
/// @throws InvalidTransformationDescriptionException if the id is missing or not an
///         unsigned number, if no element has that id, or if that element is not an
///         EinsumNode.
Einsum2Dot Einsum2Dot::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    // Validate with an exception instead of assert: asserts disappear under NDEBUG, and
    // json::operator[] on a const object with a missing key is undefined behavior.
    // No "impl_type" key is required: to_json never writes one, so requiring it would
    // break the to_json -> from_json round trip.
    if (!j.contains("einsum_node_element_id") || !j.at("einsum_node_element_id").is_number_unsigned()) {
        throw InvalidTransformationDescriptionException(
            "Missing or invalid 'einsum_node_element_id' in Einsum2Dot description"
        );
    }

    const size_t einsum_node_id = j.at("einsum_node_element_id").get<size_t>();
    auto* einsum_node_element = builder.find_element_by_id(einsum_node_id);
    if (!einsum_node_element) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " not found"
        );
    }
    auto* einsum_node = dynamic_cast<einsum::EinsumNode*>(einsum_node_element);
    if (!einsum_node) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " is not an EinsumNode"
        );
    }

    // Descriptions without a tune fall back to "none", for which get_impl_type
    // yields no implementation, so can_be_applied will reject the transformation.
    const std::string target_tune = j.contains("target_tune") ? j.at("target_tune").get<std::string>()
                                                              : std::string("none");

    return Einsum2Dot(*einsum_node, target_tune);
}
×
227

228
} // namespace transformations
229
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc