• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23484120816

24 Mar 2026 10:12AM UTC coverage: 64.587% (+0.3%) from 64.295%
23484120816

Pull #605

github

web-flow
Merge 03f1c151d into 0428b1aa7
Pull Request #605: Move einsum support

1307 of 1836 new or added lines in 10 files covered. (71.19%)

45 existing lines in 3 files now uncovered.

27956 of 43284 relevant lines covered (64.59%)

393.57 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.13
/opt/src/transformations/einsum2gemm.cpp
1
#include "sdfg/transformations/einsum2gemm.h"
2

3
#include <algorithm>
4
#include <cassert>
5
#include <cstddef>
6
#include <nlohmann/json_fwd.hpp>
7
#include <optional>
8
#include <string>
9
#include <vector>
10

11
#include "sdfg/analysis/analysis.h"
12
#include "sdfg/builder/structured_sdfg_builder.h"
13
#include "sdfg/data_flow/library_node.h"
14
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
15
#include "sdfg/data_flow/library_nodes/math/math.h"
16
#include "sdfg/einsum/einsum.h"
17
#include "sdfg/symbolic/symbolic.h"
18
#include "sdfg/transformations/transformation.h"
19
#include "sdfg/types/scalar.h"
20
#include "sdfg/types/type.h"
21
#include "symengine/symengine_rcp.h"
22

23
namespace sdfg {
24
namespace transformations {
25

26
/// Checks that input `mat` of the einsum node is indexed by the two given
/// induction variables.
///
/// @param mat     Position of the input whose first two indices are inspected.
/// @param indvar1 First admissible induction variable.
/// @param indvar2 Second admissible induction variable.
/// @return true iff the first and the second index of input `mat` are not equal
///         to each other and each of them equals `indvar1` or `indvar2`.
bool Einsum2Gemm::check_matrix_indices(long long mat, const symbolic::Symbol& indvar1, const symbolic::Symbol& indvar2) {
    const auto& first_index = this->einsum_node_.in_index(mat, 0);
    const auto& second_index = this->einsum_node_.in_index(mat, 1);

    // Both positions must carry different indices.
    if (symbolic::eq(first_index, second_index)) {
        return false;
    }

    // The first position has to be one of the two admissible induction variables ...
    if (!symbolic::eq(first_index, indvar1) && !symbolic::eq(first_index, indvar2)) {
        return false;
    }

    // ... and so has the second position.
    return symbolic::eq(second_index, indvar1) || symbolic::eq(second_index, indvar2);
}
4✔
33

34
/// Creates the transformation for one einsum node.
///
/// @param einsum_node Einsum node this transformation operates on; kept in `einsum_node_`.
/// @param target_tune Target identifier kept in `target_tune_`; get_impl_type()
///                    compares it against "openmp" and "tenstorrent".
Einsum2Gemm::Einsum2Gemm(einsum::EinsumNode& einsum_node, const std::string& target_tune)
    : einsum_node_(einsum_node),
      target_tune_(target_tune) {
}
2✔
36

NEW
37
/// Returns the identifier of this transformation, "Einsum2Gemm".
/// to_json() stores this value under the key "transformation_type".
std::string Einsum2Gemm::name() const {
    return std::string("Einsum2Gemm");
}
×
38

39
/// Selects the GEMM implementation type for the configured target.
///
/// @param data_type Primitive type of the GEMM operands.
/// @return The BLAS implementation type for target "openmp"; the
///         "TENSTORRENT_WithTransfers" implementation type for target
///         "tenstorrent" when `data_type` is Float; std::nullopt in every
///         other case.
std::optional<sdfg::data_flow::ImplementationType> Einsum2Gemm::get_impl_type(types::PrimitiveType data_type) {
    if (this->target_tune_ == "openmp") {
        return std::make_optional(sdfg::math::blas::ImplementationType_BLAS);
    }

    // The tenstorrent target is only selected for single precision.
    if (this->target_tune_ == "tenstorrent" && data_type == types::PrimitiveType::Float) {
        return data_flow::ImplementationType{"TENSTORRENT_WithTransfers"};
    }

    // TODO: Implement GEMM dispatcher for CUBLAS

    return std::nullopt;
}
4✔
52

53
bool Einsum2Gemm::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
54
    // Check dims
55
    if (this->einsum_node_.dims().size() != 3) {
2✔
NEW
56
        return false;
×
NEW
57
    }
×
58

59
    // Check initial values
60
    for (size_t i = 0; i < 3; i++) {
8✔
61
        if (!symbolic::eq(this->einsum_node_.init(i), symbolic::zero())) {
6✔
NEW
62
            return false;
×
NEW
63
        }
×
64
    }
6✔
65

66
    // Check out indices
67
    if (this->einsum_node_.out_indices().size() != 2) {
2✔
NEW
68
        return false;
×
NEW
69
    }
×
70
    symbolic::Symbol indvar_outer_1 = SymEngine::null, indvar_outer_2 = SymEngine::null, indvar_inner = SymEngine::null;
2✔
71
    std::vector<size_t> permutation = {0, 1, 2};
2✔
72
    do {
2✔
73
        if (symbolic::eq(this->einsum_node_.out_index(0), this->einsum_node_.indvar(permutation[0])) &&
2✔
74
            symbolic::eq(this->einsum_node_.out_index(1), this->einsum_node_.indvar(permutation[1]))) {
2✔
75
            indvar_outer_1 = this->einsum_node_.indvar(permutation[0]);
2✔
76
            indvar_outer_2 = this->einsum_node_.indvar(permutation[1]);
2✔
77
            indvar_inner = this->einsum_node_.indvar(permutation[2]);
2✔
78
            break;
2✔
79
        }
2✔
80
    } while (std::next_permutation(permutation.begin(), permutation.end()));
2✔
81
    if (indvar_outer_1.is_null() || indvar_outer_2.is_null() || indvar_inner.is_null()) {
2✔
NEW
82
        return false;
×
NEW
83
    }
×
84

85
    // Check bounds, i.e., preven triangular access
86
    for (size_t i = 0; i < 3; i++) {
8✔
87
        if (symbolic::uses(this->einsum_node_.bound(i), indvar_outer_1) ||
6✔
88
            symbolic::uses(this->einsum_node_.bound(i), indvar_outer_2) ||
6✔
89
            symbolic::uses(this->einsum_node_.bound(i), indvar_inner)) {
6✔
NEW
90
            return false;
×
NEW
91
        }
×
92
    }
6✔
93

94
    // Check inputs
95
    long long A = -1, B = -1, C = -1;
2✔
96
    if (this->einsum_node_.inputs().size() == 3) {
2✔
97
        C = 2;
1✔
98
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
3✔
99
            if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
100
                break;
×
101
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
102
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
103
                A = i;
1✔
104
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
105
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
106
                B = i;
1✔
107
            }
1✔
108
        }
2✔
109
    } else if (this->einsum_node_.inputs().size() == 4) {
1✔
110
        C = 3;
1✔
111
        long long alpha = -1;
1✔
112
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
4✔
113
            if (this->einsum_node_.in_indices(i).size() == 0) {
3✔
114
                alpha = i;
1✔
115
            } else if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
116
                break;
×
117
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
118
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
119
                A = i;
1✔
120
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
121
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
122
                B = i;
1✔
123
            }
1✔
124
        }
3✔
125

126
        // Check alpha
127
        if (alpha == -1 || this->einsum_node_.in_indices(alpha).size() != 0) {
1✔
NEW
128
            return false;
×
NEW
129
        }
×
130
    } else {
1✔
NEW
131
        return false;
×
NEW
132
    }
×
133
    if (A == -1 || B == -1 || A == B || this->einsum_node_.input(C) != this->einsum_node_.output(0)) {
2✔
NEW
134
        return false;
×
NEW
135
    }
×
136

137
    // Check in indices
138
    if (this->einsum_node_.in_indices(A).size() != 2 || !this->check_matrix_indices(A, indvar_outer_1, indvar_inner)) {
2✔
NEW
139
        return false;
×
NEW
140
    }
×
141
    if (this->einsum_node_.in_indices(B).size() != 2 || !this->check_matrix_indices(B, indvar_inner, indvar_outer_2)) {
2✔
NEW
142
        return false;
×
NEW
143
    }
×
144
    if (this->einsum_node_.in_indices(C).size() != 2 ||
2✔
145
        !symbolic::eq(this->einsum_node_.in_index(C, 0), indvar_outer_1) ||
2✔
146
        !symbolic::eq(this->einsum_node_.in_index(C, 1), indvar_outer_2)) {
2✔
NEW
147
        return false;
×
NEW
148
    }
×
149

150
    // Get the data flow graph
151
    auto& dfg = this->einsum_node_.get_parent();
2✔
152

153
    // Determine and check the base type of output
154
    auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
155
    auto data_type = oedge.base_type().primitive_type();
2✔
156
    if (data_type != types::PrimitiveType::Float && data_type != types::PrimitiveType::Double) {
2✔
NEW
157
        return false;
×
NEW
158
    }
×
159

160
    // Check if all inputs have the same primitive type
161
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
7✔
162
        if (iedge.base_type().primitive_type() != data_type) {
7✔
NEW
163
            return false;
×
NEW
164
        }
×
165
    }
7✔
166

167
    if (!this->get_impl_type(data_type)) { // no implementation for the given tune exists
2✔
NEW
168
        return false;
×
NEW
169
    }
×
170

171
    return true;
2✔
172
}
2✔
173

174
void Einsum2Gemm::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
175
    // Get the data flow graph
176
    auto& dfg = this->einsum_node_.get_parent();
2✔
177

178
    // Get the block in which the einsum node lives
179
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
2✔
180
    assert(block);
2✔
181

182
    // Determine the BLAS precision
183
    math::blas::BLAS_Precision precision;
2✔
184
    auto& datatype_oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
185
    types::PrimitiveType data_type = datatype_oedge.base_type().primitive_type();
2✔
186
    if (data_type == types::PrimitiveType::Float) {
2✔
187
        precision = math::blas::BLAS_Precision::s;
2✔
188
    } else {
2✔
NEW
189
        precision = math::blas::BLAS_Precision::d;
×
NEW
190
    }
×
191

192
    // Determine indvars
193
    symbolic::Symbol indvar_outer_1 = SymEngine::null, indvar_outer_2 = SymEngine::null, indvar_inner = SymEngine::null;
2✔
194
    symbolic::Expression m = SymEngine::null, n = SymEngine::null, k = SymEngine::null;
2✔
195
    std::vector<size_t> permutation = {0, 1, 2};
2✔
196
    do {
2✔
197
        if (symbolic::eq(this->einsum_node_.out_index(0), this->einsum_node_.indvar(permutation[0])) &&
2✔
198
            symbolic::eq(this->einsum_node_.out_index(1), this->einsum_node_.indvar(permutation[1]))) {
2✔
199
            indvar_outer_1 = this->einsum_node_.indvar(permutation[0]);
2✔
200
            indvar_outer_2 = this->einsum_node_.indvar(permutation[1]);
2✔
201
            indvar_inner = this->einsum_node_.indvar(permutation[2]);
2✔
202
            m = this->einsum_node_.bound(permutation[0]);
2✔
203
            n = this->einsum_node_.bound(permutation[1]);
2✔
204
            k = this->einsum_node_.bound(permutation[2]);
2✔
205
            break;
2✔
206
        }
2✔
207
    } while (std::next_permutation(permutation.begin(), permutation.end()));
2✔
208
    assert(
2✔
209
        !indvar_outer_1.is_null() && !indvar_outer_2.is_null() && !indvar_inner.is_null() && !m.is_null() &&
2✔
210
        !n.is_null() && !k.is_null()
2✔
211
    );
2✔
212

213
    // Determine inputs
214
    long long alpha = -1, A = -1, B = -1, C = -1;
2✔
215
    bool has_alpha = false;
2✔
216
    if (this->einsum_node_.inputs().size() == 3) {
2✔
217
        C = 2;
1✔
218
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
3✔
219
            if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
220
                break;
×
221
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
222
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
223
                A = i;
1✔
224
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
225
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
226
                B = i;
1✔
227
            }
1✔
228
        }
2✔
229
    } else if (this->einsum_node_.inputs().size() == 4) {
1✔
230
        C = 3;
1✔
231
        has_alpha = true;
1✔
232
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
4✔
233
            if (this->einsum_node_.in_indices(i).size() == 0) {
3✔
234
                alpha = i;
1✔
235
            } else if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
236
                break;
×
237
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
238
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
239
                A = i;
1✔
240
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
241
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
242
                B = i;
1✔
243
            }
1✔
244
        }
3✔
245
    }
1✔
246

247
    // Determine transpose and leading dimensions
248
    math::blas::BLAS_Transpose transA, transB;
2✔
249
    symbolic::Expression ldA, ldB, ldC;
2✔
250
    if (symbolic::eq(this->einsum_node_.in_index(A, 0), indvar_outer_1)) {
2✔
251
        transA = math::blas::BLAS_Transpose::No;
2✔
252
        ldA = k;
2✔
253
    } else {
2✔
NEW
254
        transA = math::blas::BLAS_Transpose::Trans;
×
NEW
255
        ldA = m;
×
NEW
256
    }
×
257
    if (symbolic::eq(this->einsum_node_.in_index(B, 1), indvar_outer_2)) {
2✔
258
        transB = math::blas::BLAS_Transpose::No;
2✔
259
        ldB = n;
2✔
260
    } else {
2✔
NEW
261
        transB = math::blas::BLAS_Transpose::Trans;
×
NEW
262
        ldB = k;
×
NEW
263
    }
×
264
    ldC = n;
2✔
265

266
    // Add the BLAS node for gemm
267
    auto& libnode = builder.add_library_node<math::blas::GEMMNode>(
2✔
268
        *block,
2✔
269
        this->einsum_node_.debug_info(),
2✔
270
        this->get_impl_type(data_type).value(),
2✔
271
        precision,
2✔
272
        math::blas::BLAS_Layout::RowMajor,
2✔
273
        transA,
2✔
274
        transB,
2✔
275
        m,
2✔
276
        n,
2✔
277
        k,
2✔
278
        ldA,
2✔
279
        ldB,
2✔
280
        ldC
2✔
281
    );
2✔
282

283
    // Copy the memlets
284
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
7✔
285
        if (has_alpha && iedge.dst_conn() == this->einsum_node_.input(alpha)) {
7✔
286
            builder.add_memlet(
1✔
287
                *block,
1✔
288
                iedge.src(),
1✔
289
                iedge.src_conn(),
1✔
290
                libnode,
1✔
291
                "__alpha",
1✔
292
                iedge.subset(),
1✔
293
                iedge.base_type(),
1✔
294
                iedge.debug_info()
1✔
295
            );
1✔
296
        } else if (iedge.dst_conn() == this->einsum_node_.input(A)) {
6✔
297
            builder.add_memlet(
2✔
298
                *block,
2✔
299
                iedge.src(),
2✔
300
                iedge.src_conn(),
2✔
301
                libnode,
2✔
302
                "__A",
2✔
303
                iedge.subset(),
2✔
304
                iedge.base_type(),
2✔
305
                iedge.debug_info()
2✔
306
            );
2✔
307
        } else if (iedge.dst_conn() == this->einsum_node_.input(B)) {
4✔
308
            builder.add_memlet(
2✔
309
                *block,
2✔
310
                iedge.src(),
2✔
311
                iedge.src_conn(),
2✔
312
                libnode,
2✔
313
                "__B",
2✔
314
                iedge.subset(),
2✔
315
                iedge.base_type(),
2✔
316
                iedge.debug_info()
2✔
317
            );
2✔
318
        } else if (iedge.dst_conn() == this->einsum_node_.input(C)) {
2✔
319
            builder.add_memlet(
2✔
320
                *block,
2✔
321
                iedge.src(),
2✔
322
                iedge.src_conn(),
2✔
323
                libnode,
2✔
324
                "__C",
2✔
325
                iedge.subset(),
2✔
326
                iedge.base_type(),
2✔
327
                iedge.debug_info()
2✔
328
            );
2✔
329
        }
2✔
330
    }
7✔
331
    for (auto& oedge : dfg.out_edges(this->einsum_node_)) {
2✔
332
        if (oedge.src_conn() == this->einsum_node_.output(0)) {
2✔
333
            builder.add_memlet(
2✔
334
                *block,
2✔
335
                libnode,
2✔
336
                "__C",
2✔
337
                oedge.dst(),
2✔
338
                oedge.dst_conn(),
2✔
339
                oedge.subset(),
2✔
340
                oedge.base_type(),
2✔
341
                oedge.debug_info()
2✔
342
            );
2✔
343
        }
2✔
344
    }
2✔
345

346
    // Remove the old memlets
347
    while (dfg.in_edges(this->einsum_node_).begin() != dfg.in_edges(this->einsum_node_).end()) {
9✔
348
        auto& iedge = *dfg.in_edges(this->einsum_node_).begin();
7✔
349
        builder.remove_memlet(*block, iedge);
7✔
350
    }
7✔
351
    while (dfg.out_edges(this->einsum_node_).begin() != dfg.out_edges(this->einsum_node_).end()) {
4✔
352
        auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
353
        builder.remove_memlet(*block, oedge);
2✔
354
    }
2✔
355

356
    // Add constant scalars alpha and beta (if needed)
357
    types::Scalar data_type_scalar(data_type);
2✔
358
    if (!has_alpha) {
2✔
359
        auto& alpha_access_node =
1✔
360
            builder.add_constant(*block, "1.0", data_type_scalar, this->einsum_node_.debug_info());
1✔
361
        builder.add_memlet(
1✔
362
            *block, alpha_access_node, "void", libnode, "__alpha", {}, data_type_scalar, this->einsum_node_.debug_info()
1✔
363
        );
1✔
364
    }
1✔
365
    auto& beta_access_node = builder.add_constant(*block, "1.0", data_type_scalar, this->einsum_node_.debug_info());
2✔
366
    builder.add_memlet(
2✔
367
        *block, beta_access_node, "void", libnode, "__beta", {}, data_type_scalar, this->einsum_node_.debug_info()
2✔
368
    );
2✔
369

370
    // Remove the einsum node
371
    builder.remove_node(*block, this->einsum_node_);
2✔
372

373
    analysis_manager.invalidate_all();
2✔
374
}
2✔
375

NEW
376
/// Serializes this transformation into `j`.
///
/// Writes three keys: "transformation_type" (the value of name()),
/// "einsum_node_element_id" (element id of the einsum node) and
/// "target_tune" (the configured target identifier).
///
/// @param j JSON value that receives the keys.
void Einsum2Gemm::to_json(nlohmann::json& j) const {
    const std::string type_name = this->name();
    const auto node_id = this->einsum_node_.element_id();

    j["transformation_type"] = type_name;
    j["einsum_node_element_id"] = node_id;
    j["target_tune"] = this->target_tune_;
}
×
381

NEW
382
/// Reconstructs the transformation from its JSON description.
///
/// Reads "einsum_node_element_id" (required, unsigned number) and
/// "target_tune" (optional string, defaults to "none"). These are the keys
/// written by to_json(); no other key is required.
///
/// @param builder Builder used to look up the einsum node by its element id.
/// @param j       JSON description of the transformation.
/// @return The transformation bound to the referenced einsum node.
/// @throws InvalidTransformationDescriptionException if the element id is
///         missing or not an unsigned number, if no element with that id
///         exists, or if the element is not an einsum::EinsumNode.
Einsum2Gemm Einsum2Gemm::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    // Validate with exceptions instead of assert(): assertions vanish in
    // release builds and would leave the key access below unchecked.
    if (!j.contains("einsum_node_element_id") || !j.at("einsum_node_element_id").is_number_unsigned()) {
        throw InvalidTransformationDescriptionException(
            "Key 'einsum_node_element_id' is missing or not an unsigned number"
        );
    }

    const size_t einsum_node_id = j.at("einsum_node_element_id").get<size_t>();
    auto* einsum_node_element = builder.find_element_by_id(einsum_node_id);
    if (!einsum_node_element) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " not found"
        );
    }
    auto* einsum_node = dynamic_cast<einsum::EinsumNode*>(einsum_node_element);
    if (!einsum_node) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " is not an EinsumNode"
        );
    }

    // The target is optional; descriptions without it fall back to "none".
    std::string target_tune = "none";
    if (j.contains("target_tune")) {
        target_tune = j.at("target_tune").get<std::string>();
    }

    return Einsum2Gemm(*einsum_node, target_tune);
}
×
410

411
} // namespace transformations
412
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc