• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 26352712167

22 May 2026 02:58PM UTC coverage: 60.848% (+0.01%) from 60.837%
26352712167

push

github

web-flow
Activate Einsum pipeline for Python/PyTorch frontend (#720)

- Added Einsum step to compilation (after which the SDFG is also dumped)
- Generate tasklets and cmath library nodes directly instead of tensor nodes if all tensor types are scalar
- Merged the whole "lifting" part of Einsum nodes into one pass: the EinsumDetectionPass that runs in linear time
- Renamed the EinsumExpand transformation to the EinsumPromotion transformation because the old name was easily confused with expanding Einsum/library nodes
- Fixed a bug where the lifting failed for multiple input edges of the same access node on a tasklet and added unit tests
- Moved Einsum node source files into tensor folder (but the Einsum node is not a tensor node yet)

85 of 115 new or added lines in 13 files covered. (73.91%)

4 existing lines in 2 files now uncovered.

35024 of 57560 relevant lines covered (60.85%)

11105.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.07
/opt/src/transformations/einsum2gemm.cpp
1
#include "sdfg/transformations/einsum2gemm.h"
2

3
#include <algorithm>
4
#include <cassert>
5
#include <cstddef>
6
#include <nlohmann/json_fwd.hpp>
7
#include <optional>
8
#include <string>
9
#include <vector>
10

11
#include "sdfg/analysis/analysis.h"
12
#include "sdfg/builder/structured_sdfg_builder.h"
13
#include "sdfg/data_flow/library_node.h"
14
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
15
#include "sdfg/data_flow/library_nodes/math/math.h"
16
#include "sdfg/data_flow/library_nodes/math/tensor/einsum_node.h"
17
#include "sdfg/symbolic/symbolic.h"
18
#include "sdfg/transformations/transformation.h"
19
#include "sdfg/types/scalar.h"
20
#include "sdfg/types/type.h"
21
#include "symengine/symengine_rcp.h"
22

23
namespace sdfg {
24
namespace transformations {
25

26
/**
 * Check that input `mat` is a matrix access indexed by exactly {indvar1, indvar2}.
 *
 * Both index positions must be one of the two given symbols, and the two positions
 * must differ, so the accepted forms are [indvar1, indvar2] and [indvar2, indvar1].
 * The caller is responsible for ensuring input `mat` has two indices.
 */
bool Einsum2Gemm::check_matrix_indices(long long mat, const symbolic::Symbol& indvar1, const symbolic::Symbol& indvar2) {
    const auto& row_index = this->einsum_node_.in_index(mat, 0);
    const auto& col_index = this->einsum_node_.in_index(mat, 1);

    // A diagonal-style access such as M[i, i] is not a plain matrix operand
    if (symbolic::eq(row_index, col_index)) {
        return false;
    }

    const bool row_matches = symbolic::eq(row_index, indvar1) || symbolic::eq(row_index, indvar2);
    if (!row_matches) {
        return false;
    }
    return symbolic::eq(col_index, indvar1) || symbolic::eq(col_index, indvar2);
}
33

34
/// Create the transformation for the einsum node it targets.
Einsum2Gemm::Einsum2Gemm(math::tensor::EinsumNode& einsum_node)
    : einsum_node_(einsum_node) {}
35

36
/// Name of this transformation (to_json() stores it under "transformation_type").
std::string Einsum2Gemm::name() const {
    return std::string{"Einsum2Gemm"};
}
37

38
bool Einsum2Gemm::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
39
    // Check dims
40
    if (this->einsum_node_.dims().size() != 3) {
2✔
41
        return false;
×
42
    }
×
43

44
    // Check initial values
45
    for (size_t i = 0; i < 3; i++) {
8✔
46
        if (!symbolic::eq(this->einsum_node_.init(i), symbolic::zero())) {
6✔
47
            return false;
×
48
        }
×
49
    }
6✔
50

51
    // Check out indices
52
    if (this->einsum_node_.out_indices().size() != 2) {
2✔
53
        return false;
×
54
    }
×
55
    symbolic::Symbol indvar_outer_1 = SymEngine::null, indvar_outer_2 = SymEngine::null, indvar_inner = SymEngine::null;
2✔
56
    std::vector<size_t> permutation = {0, 1, 2};
2✔
57
    do {
2✔
58
        if (symbolic::eq(this->einsum_node_.out_index(0), this->einsum_node_.indvar(permutation[0])) &&
2✔
59
            symbolic::eq(this->einsum_node_.out_index(1), this->einsum_node_.indvar(permutation[1]))) {
2✔
60
            indvar_outer_1 = this->einsum_node_.indvar(permutation[0]);
2✔
61
            indvar_outer_2 = this->einsum_node_.indvar(permutation[1]);
2✔
62
            indvar_inner = this->einsum_node_.indvar(permutation[2]);
2✔
63
            break;
2✔
64
        }
2✔
65
    } while (std::next_permutation(permutation.begin(), permutation.end()));
2✔
66
    if (indvar_outer_1.is_null() || indvar_outer_2.is_null() || indvar_inner.is_null()) {
2✔
67
        return false;
×
68
    }
×
69

70
    // Check bounds, i.e., preven triangular access
71
    for (size_t i = 0; i < 3; i++) {
8✔
72
        if (symbolic::uses(this->einsum_node_.bound(i), indvar_outer_1) ||
6✔
73
            symbolic::uses(this->einsum_node_.bound(i), indvar_outer_2) ||
6✔
74
            symbolic::uses(this->einsum_node_.bound(i), indvar_inner)) {
6✔
75
            return false;
×
76
        }
×
77
    }
6✔
78

79
    // Check inputs
80
    long long A = -1, B = -1, C = -1;
2✔
81
    if (this->einsum_node_.inputs().size() == 3) {
2✔
82
        C = 2;
1✔
83
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
3✔
84
            if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
85
                break;
×
86
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
87
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
88
                A = i;
1✔
89
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
90
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
91
                B = i;
1✔
92
            }
1✔
93
        }
2✔
94
    } else if (this->einsum_node_.inputs().size() == 4) {
1✔
95
        C = 3;
1✔
96
        long long alpha = -1;
1✔
97
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
4✔
98
            if (this->einsum_node_.in_indices(i).size() == 0) {
3✔
99
                alpha = i;
1✔
100
            } else if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
101
                break;
×
102
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
103
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
104
                A = i;
1✔
105
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
106
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
107
                B = i;
1✔
108
            }
1✔
109
        }
3✔
110

111
        // Check alpha
112
        if (alpha == -1 || this->einsum_node_.in_indices(alpha).size() != 0) {
1✔
113
            return false;
×
114
        }
×
115
    } else {
1✔
116
        return false;
×
117
    }
×
118
    if (A == -1 || B == -1 || A == B || this->einsum_node_.input(C) != this->einsum_node_.output(0)) {
2✔
119
        return false;
×
120
    }
×
121

122
    // Check in indices
123
    if (this->einsum_node_.in_indices(A).size() != 2 || !this->check_matrix_indices(A, indvar_outer_1, indvar_inner)) {
2✔
124
        return false;
×
125
    }
×
126
    if (this->einsum_node_.in_indices(B).size() != 2 || !this->check_matrix_indices(B, indvar_inner, indvar_outer_2)) {
2✔
127
        return false;
×
128
    }
×
129
    if (this->einsum_node_.in_indices(C).size() != 2 ||
2✔
130
        !symbolic::eq(this->einsum_node_.in_index(C, 0), indvar_outer_1) ||
2✔
131
        !symbolic::eq(this->einsum_node_.in_index(C, 1), indvar_outer_2)) {
2✔
132
        return false;
×
133
    }
×
134

135
    // Get the data flow graph
136
    auto& dfg = this->einsum_node_.get_parent();
2✔
137

138
    // Determine and check the base type of output
139
    auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
140
    auto data_type = oedge.base_type().primitive_type();
2✔
141
    if (data_type != types::PrimitiveType::Float && data_type != types::PrimitiveType::Double) {
2✔
142
        return false;
×
143
    }
×
144

145
    // Check if all inputs have the same primitive type
146
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
7✔
147
        if (iedge.base_type().primitive_type() != data_type) {
7✔
148
            return false;
×
149
        }
×
150
    }
7✔
151

152
    return true;
2✔
153
}
2✔
154

155
void Einsum2Gemm::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
156
    // Get the data flow graph
157
    auto& dfg = this->einsum_node_.get_parent();
2✔
158

159
    // Get the block in which the einsum node lives
160
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
2✔
161
    assert(block);
2✔
162

163
    // Determine the BLAS precision
164
    math::blas::BLAS_Precision precision;
2✔
165
    auto& datatype_oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
166
    types::PrimitiveType data_type = datatype_oedge.base_type().primitive_type();
2✔
167
    if (data_type == types::PrimitiveType::Float) {
2✔
168
        precision = math::blas::BLAS_Precision::s;
2✔
169
    } else {
2✔
170
        precision = math::blas::BLAS_Precision::d;
×
171
    }
×
172

173
    // Determine indvars
174
    symbolic::Symbol indvar_outer_1 = SymEngine::null, indvar_outer_2 = SymEngine::null, indvar_inner = SymEngine::null;
2✔
175
    symbolic::Expression m = SymEngine::null, n = SymEngine::null, k = SymEngine::null;
2✔
176
    std::vector<size_t> permutation = {0, 1, 2};
2✔
177
    do {
2✔
178
        if (symbolic::eq(this->einsum_node_.out_index(0), this->einsum_node_.indvar(permutation[0])) &&
2✔
179
            symbolic::eq(this->einsum_node_.out_index(1), this->einsum_node_.indvar(permutation[1]))) {
2✔
180
            indvar_outer_1 = this->einsum_node_.indvar(permutation[0]);
2✔
181
            indvar_outer_2 = this->einsum_node_.indvar(permutation[1]);
2✔
182
            indvar_inner = this->einsum_node_.indvar(permutation[2]);
2✔
183
            m = this->einsum_node_.bound(permutation[0]);
2✔
184
            n = this->einsum_node_.bound(permutation[1]);
2✔
185
            k = this->einsum_node_.bound(permutation[2]);
2✔
186
            break;
2✔
187
        }
2✔
188
    } while (std::next_permutation(permutation.begin(), permutation.end()));
2✔
189
    assert(
2✔
190
        !indvar_outer_1.is_null() && !indvar_outer_2.is_null() && !indvar_inner.is_null() && !m.is_null() &&
2✔
191
        !n.is_null() && !k.is_null()
2✔
192
    );
2✔
193

194
    // Determine inputs
195
    long long alpha = -1, A = -1, B = -1, C = -1;
2✔
196
    bool has_alpha = false;
2✔
197
    if (this->einsum_node_.inputs().size() == 3) {
2✔
198
        C = 2;
1✔
199
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
3✔
200
            if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
201
                break;
×
202
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
203
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
204
                A = i;
1✔
205
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
206
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
207
                B = i;
1✔
208
            }
1✔
209
        }
2✔
210
    } else if (this->einsum_node_.inputs().size() == 4) {
1✔
211
        C = 3;
1✔
212
        has_alpha = true;
1✔
213
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
4✔
214
            if (this->einsum_node_.in_indices(i).size() == 0) {
3✔
215
                alpha = i;
1✔
216
            } else if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
217
                break;
×
218
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
219
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
220
                A = i;
1✔
221
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
222
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
223
                B = i;
1✔
224
            }
1✔
225
        }
3✔
226
    }
1✔
227

228
    // Determine transpose and leading dimensions
229
    math::blas::BLAS_Transpose transA, transB;
2✔
230
    symbolic::Expression ldA, ldB, ldC;
2✔
231
    if (symbolic::eq(this->einsum_node_.in_index(A, 0), indvar_outer_1)) {
2✔
232
        transA = math::blas::BLAS_Transpose::No;
2✔
233
        ldA = k;
2✔
234
    } else {
2✔
235
        transA = math::blas::BLAS_Transpose::Trans;
×
236
        ldA = m;
×
237
    }
×
238
    if (symbolic::eq(this->einsum_node_.in_index(B, 1), indvar_outer_2)) {
2✔
239
        transB = math::blas::BLAS_Transpose::No;
2✔
240
        ldB = n;
2✔
241
    } else {
2✔
242
        transB = math::blas::BLAS_Transpose::Trans;
×
243
        ldB = k;
×
244
    }
×
245
    ldC = n;
2✔
246

247
    // Add the BLAS node for gemm
248
    auto& libnode = builder.add_library_node<math::blas::GEMMNode>(
2✔
249
        *block,
2✔
250
        this->einsum_node_.debug_info(),
2✔
251
        sdfg::math::blas::ImplementationType_BLAS,
2✔
252
        precision,
2✔
253
        math::blas::BLAS_Layout::RowMajor,
2✔
254
        transA,
2✔
255
        transB,
2✔
256
        m,
2✔
257
        n,
2✔
258
        k,
2✔
259
        ldA,
2✔
260
        ldB,
2✔
261
        ldC
2✔
262
    );
2✔
263

264
    // Copy the memlets
265
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
7✔
266
        if (has_alpha && iedge.dst_conn() == this->einsum_node_.input(alpha)) {
7✔
267
            builder.add_memlet(
1✔
268
                *block,
1✔
269
                iedge.src(),
1✔
270
                iedge.src_conn(),
1✔
271
                libnode,
1✔
272
                "__alpha",
1✔
273
                iedge.subset(),
1✔
274
                iedge.base_type(),
1✔
275
                iedge.debug_info()
1✔
276
            );
1✔
277
        } else if (iedge.dst_conn() == this->einsum_node_.input(A)) {
6✔
278
            builder.add_memlet(
2✔
279
                *block,
2✔
280
                iedge.src(),
2✔
281
                iedge.src_conn(),
2✔
282
                libnode,
2✔
283
                "__A",
2✔
284
                iedge.subset(),
2✔
285
                iedge.base_type(),
2✔
286
                iedge.debug_info()
2✔
287
            );
2✔
288
        } else if (iedge.dst_conn() == this->einsum_node_.input(B)) {
4✔
289
            builder.add_memlet(
2✔
290
                *block,
2✔
291
                iedge.src(),
2✔
292
                iedge.src_conn(),
2✔
293
                libnode,
2✔
294
                "__B",
2✔
295
                iedge.subset(),
2✔
296
                iedge.base_type(),
2✔
297
                iedge.debug_info()
2✔
298
            );
2✔
299
        } else if (iedge.dst_conn() == this->einsum_node_.input(C)) {
2✔
300
            builder.add_memlet(
2✔
301
                *block,
2✔
302
                iedge.src(),
2✔
303
                iedge.src_conn(),
2✔
304
                libnode,
2✔
305
                "__C",
2✔
306
                iedge.subset(),
2✔
307
                iedge.base_type(),
2✔
308
                iedge.debug_info()
2✔
309
            );
2✔
310
        }
2✔
311
    }
7✔
312
    for (auto& oedge : dfg.out_edges(this->einsum_node_)) {
2✔
313
        if (oedge.src_conn() == this->einsum_node_.output(0)) {
2✔
314
            builder.add_memlet(
2✔
315
                *block,
2✔
316
                libnode,
2✔
317
                "__C",
2✔
318
                oedge.dst(),
2✔
319
                oedge.dst_conn(),
2✔
320
                oedge.subset(),
2✔
321
                oedge.base_type(),
2✔
322
                oedge.debug_info()
2✔
323
            );
2✔
324
        }
2✔
325
    }
2✔
326

327
    // Remove the old memlets
328
    while (dfg.in_edges(this->einsum_node_).begin() != dfg.in_edges(this->einsum_node_).end()) {
9✔
329
        auto& iedge = *dfg.in_edges(this->einsum_node_).begin();
7✔
330
        builder.remove_memlet(*block, iedge);
7✔
331
    }
7✔
332
    while (dfg.out_edges(this->einsum_node_).begin() != dfg.out_edges(this->einsum_node_).end()) {
4✔
333
        auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
334
        builder.remove_memlet(*block, oedge);
2✔
335
    }
2✔
336

337
    // Add constant scalars alpha and beta (if needed)
338
    types::Scalar data_type_scalar(data_type);
2✔
339
    if (!has_alpha) {
2✔
340
        auto& alpha_access_node =
1✔
341
            builder.add_constant(*block, "1.0", data_type_scalar, this->einsum_node_.debug_info());
1✔
342
        builder.add_memlet(
1✔
343
            *block, alpha_access_node, "void", libnode, "__alpha", {}, data_type_scalar, this->einsum_node_.debug_info()
1✔
344
        );
1✔
345
    }
1✔
346
    auto& beta_access_node = builder.add_constant(*block, "1.0", data_type_scalar, this->einsum_node_.debug_info());
2✔
347
    builder.add_memlet(
2✔
348
        *block, beta_access_node, "void", libnode, "__beta", {}, data_type_scalar, this->einsum_node_.debug_info()
2✔
349
    );
2✔
350

351
    // Remove the einsum node
352
    builder.remove_node(*block, this->einsum_node_);
2✔
353

354
    analysis_manager.invalidate_all();
2✔
355
}
2✔
356

357
/**
 * Serialize this transformation as its type name plus the element id of the target
 * einsum node; from_json() resolves that id back to the node.
 */
void Einsum2Gemm::to_json(nlohmann::json& j) const {
    const auto target_id = this->einsum_node_.element_id();
    j["transformation_type"] = this->name();
    j["einsum_node_element_id"] = target_id;
}
361

362
/**
 * Reconstruct an Einsum2Gemm transformation from a description written by to_json().
 *
 * @param builder Builder of the SDFG that contains the referenced einsum node.
 * @param j       Description containing "einsum_node_element_id".
 * @return        A transformation bound to the referenced einsum node.
 * @throws InvalidTransformationDescriptionException if the id is missing or not an
 *         unsigned integer, if no element has that id, or if that element is not an
 *         EinsumNode.
 */
Einsum2Gemm Einsum2Gemm::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    // Validate at runtime instead of with assert(): asserts are compiled out in release
    // builds, and indexing a const json with a missing key is undefined behavior.
    if (!j.contains("einsum_node_element_id") || !j.at("einsum_node_element_id").is_number_unsigned()) {
        throw InvalidTransformationDescriptionException(
            "Missing or non-unsigned-integer field 'einsum_node_element_id'"
        );
    }

    const size_t einsum_node_id = j.at("einsum_node_element_id").get<size_t>();
    auto* einsum_node_element = builder.find_element_by_id(einsum_node_id);
    if (!einsum_node_element) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " not found"
        );
    }

    auto* einsum_node = dynamic_cast<math::tensor::EinsumNode*>(einsum_node_element);
    if (!einsum_node) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " is not an EinsumNode"
        );
    }

    return Einsum2Gemm(*einsum_node);
}
382

383
} // namespace transformations
384
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc