• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23490543373

24 Mar 2026 12:56PM UTC coverage: 64.456% (+0.2%) from 64.295%
23490543373

Pull #605

github

web-flow
Merge 28bc2690b into e56781552
Pull Request #605: Move einsum support

1303 of 1918 new or added lines in 14 files covered. (67.94%)

45 existing lines in 3 files now uncovered.

27952 of 43366 relevant lines covered (64.46%)

392.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.61
/opt/src/transformations/einsum2gemm.cpp
1
#include "sdfg/transformations/einsum2gemm.h"
2

3
#include <algorithm>
4
#include <cassert>
5
#include <cstddef>
6
#include <nlohmann/json_fwd.hpp>
7
#include <optional>
8
#include <string>
9
#include <vector>
10

11
#include "sdfg/analysis/analysis.h"
12
#include "sdfg/builder/structured_sdfg_builder.h"
13
#include "sdfg/data_flow/library_node.h"
14
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
15
#include "sdfg/data_flow/library_nodes/math/math.h"
16
#include "sdfg/einsum/einsum.h"
17
#include "sdfg/symbolic/symbolic.h"
18
#include "sdfg/targets/tenstorrent/library_node_mapping.h"
19
#include "sdfg/transformations/transformation.h"
20
#include "sdfg/types/scalar.h"
21
#include "sdfg/types/type.h"
22
#include "symengine/symengine_rcp.h"
23

24
namespace sdfg {
25
namespace transformations {
26

27
// Checks the two subscripts of the einsum input `mat`: they must differ from
// each other, and each of them must be equal to `indvar1` or `indvar2`.
//
// Callers are expected to have verified that input `mat` has two subscripts.
bool Einsum2Gemm::check_matrix_indices(long long mat, const symbolic::Symbol& indvar1, const symbolic::Symbol& indvar2) {
    auto& node = this->einsum_node_;
    const auto& first = node.in_index(mat, 0);
    const auto& second = node.in_index(mat, 1);

    // Reject accesses that use the same subscript twice
    if (symbolic::eq(first, second)) {
        return false;
    }

    // The first subscript has to be one of the two induction variables ...
    if (!symbolic::eq(first, indvar1) && !symbolic::eq(first, indvar2)) {
        return false;
    }

    // ... and so does the second one
    return symbolic::eq(second, indvar1) || symbolic::eq(second, indvar2);
}
4✔
34

35
// Creates the transformation for the given einsum node. `target_tune` names the
// target for which a GEMM implementation is selected later (see get_impl_type).
Einsum2Gemm::Einsum2Gemm(einsum::EinsumNode& einsum_node, const std::string& target_tune)
    : einsum_node_(einsum_node),
      target_tune_(target_tune) {
}
2✔
37

NEW
38
// Name under which this transformation identifies itself (also written to the
// "transformation_type" field by to_json).
std::string Einsum2Gemm::name() const {
    return std::string("Einsum2Gemm");
}
×
39

40
// Selects the GEMM implementation type for the configured target tune.
//
// Returns the BLAS implementation for "openmp", whatever the tenstorrent
// mapping yields for "tenstorrent" (which may be empty), and std::nullopt for
// every other tune.
std::optional<sdfg::data_flow::ImplementationType> Einsum2Gemm::get_impl_type(types::PrimitiveType data_type) {
    std::optional<sdfg::data_flow::ImplementationType> selected; // empty unless a tune matches below
    const auto& tune = this->target_tune_;

    if (tune == "tenstorrent") {
        selected = tenstorrent::try_map_library_node_implementation(math::blas::LibraryNodeType_GEMM, data_type);
    } else if (tune == "openmp") {
        selected = std::make_optional(sdfg::math::blas::ImplementationType_BLAS);
    }
    // TODO: Implement GEMM dispatcher for CUBLAS

    return selected;
}
4✔
51

52
bool Einsum2Gemm::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
53
    // Check dims
54
    if (this->einsum_node_.dims().size() != 3) {
2✔
NEW
55
        return false;
×
NEW
56
    }
×
57

58
    // Check initial values
59
    for (size_t i = 0; i < 3; i++) {
8✔
60
        if (!symbolic::eq(this->einsum_node_.init(i), symbolic::zero())) {
6✔
NEW
61
            return false;
×
NEW
62
        }
×
63
    }
6✔
64

65
    // Check out indices
66
    if (this->einsum_node_.out_indices().size() != 2) {
2✔
NEW
67
        return false;
×
NEW
68
    }
×
69
    symbolic::Symbol indvar_outer_1 = SymEngine::null, indvar_outer_2 = SymEngine::null, indvar_inner = SymEngine::null;
2✔
70
    std::vector<size_t> permutation = {0, 1, 2};
2✔
71
    do {
2✔
72
        if (symbolic::eq(this->einsum_node_.out_index(0), this->einsum_node_.indvar(permutation[0])) &&
2✔
73
            symbolic::eq(this->einsum_node_.out_index(1), this->einsum_node_.indvar(permutation[1]))) {
2✔
74
            indvar_outer_1 = this->einsum_node_.indvar(permutation[0]);
2✔
75
            indvar_outer_2 = this->einsum_node_.indvar(permutation[1]);
2✔
76
            indvar_inner = this->einsum_node_.indvar(permutation[2]);
2✔
77
            break;
2✔
78
        }
2✔
79
    } while (std::next_permutation(permutation.begin(), permutation.end()));
2✔
80
    if (indvar_outer_1.is_null() || indvar_outer_2.is_null() || indvar_inner.is_null()) {
2✔
NEW
81
        return false;
×
NEW
82
    }
×
83

84
    // Check bounds, i.e., preven triangular access
85
    for (size_t i = 0; i < 3; i++) {
8✔
86
        if (symbolic::uses(this->einsum_node_.bound(i), indvar_outer_1) ||
6✔
87
            symbolic::uses(this->einsum_node_.bound(i), indvar_outer_2) ||
6✔
88
            symbolic::uses(this->einsum_node_.bound(i), indvar_inner)) {
6✔
NEW
89
            return false;
×
NEW
90
        }
×
91
    }
6✔
92

93
    // Check inputs
94
    long long A = -1, B = -1, C = -1;
2✔
95
    if (this->einsum_node_.inputs().size() == 3) {
2✔
96
        C = 2;
1✔
97
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
3✔
98
            if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
99
                break;
×
100
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
101
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
102
                A = i;
1✔
103
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
104
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
105
                B = i;
1✔
106
            }
1✔
107
        }
2✔
108
    } else if (this->einsum_node_.inputs().size() == 4) {
1✔
109
        C = 3;
1✔
110
        long long alpha = -1;
1✔
111
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
4✔
112
            if (this->einsum_node_.in_indices(i).size() == 0) {
3✔
113
                alpha = i;
1✔
114
            } else if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
115
                break;
×
116
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
117
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
118
                A = i;
1✔
119
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
120
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
121
                B = i;
1✔
122
            }
1✔
123
        }
3✔
124

125
        // Check alpha
126
        if (alpha == -1 || this->einsum_node_.in_indices(alpha).size() != 0) {
1✔
NEW
127
            return false;
×
NEW
128
        }
×
129
    } else {
1✔
NEW
130
        return false;
×
NEW
131
    }
×
132
    if (A == -1 || B == -1 || A == B || this->einsum_node_.input(C) != this->einsum_node_.output(0)) {
2✔
NEW
133
        return false;
×
NEW
134
    }
×
135

136
    // Check in indices
137
    if (this->einsum_node_.in_indices(A).size() != 2 || !this->check_matrix_indices(A, indvar_outer_1, indvar_inner)) {
2✔
NEW
138
        return false;
×
NEW
139
    }
×
140
    if (this->einsum_node_.in_indices(B).size() != 2 || !this->check_matrix_indices(B, indvar_inner, indvar_outer_2)) {
2✔
NEW
141
        return false;
×
NEW
142
    }
×
143
    if (this->einsum_node_.in_indices(C).size() != 2 ||
2✔
144
        !symbolic::eq(this->einsum_node_.in_index(C, 0), indvar_outer_1) ||
2✔
145
        !symbolic::eq(this->einsum_node_.in_index(C, 1), indvar_outer_2)) {
2✔
NEW
146
        return false;
×
NEW
147
    }
×
148

149
    // Get the data flow graph
150
    auto& dfg = this->einsum_node_.get_parent();
2✔
151

152
    // Determine and check the base type of output
153
    auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
154
    auto data_type = oedge.base_type().primitive_type();
2✔
155
    if (data_type != types::PrimitiveType::Float && data_type != types::PrimitiveType::Double) {
2✔
NEW
156
        return false;
×
NEW
157
    }
×
158

159
    // Check if all inputs have the same primitive type
160
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
7✔
161
        if (iedge.base_type().primitive_type() != data_type) {
7✔
NEW
162
            return false;
×
NEW
163
        }
×
164
    }
7✔
165

166
    if (!this->get_impl_type(data_type)) { // no implementation for the given tune exists
2✔
NEW
167
        return false;
×
NEW
168
    }
×
169

170
    return true;
2✔
171
}
2✔
172

173
void Einsum2Gemm::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
174
    // Get the data flow graph
175
    auto& dfg = this->einsum_node_.get_parent();
2✔
176

177
    // Get the block in which the einsum node lives
178
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
2✔
179
    assert(block);
2✔
180

181
    // Determine the BLAS precision
182
    math::blas::BLAS_Precision precision;
2✔
183
    auto& datatype_oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
184
    types::PrimitiveType data_type = datatype_oedge.base_type().primitive_type();
2✔
185
    if (data_type == types::PrimitiveType::Float) {
2✔
186
        precision = math::blas::BLAS_Precision::s;
2✔
187
    } else {
2✔
NEW
188
        precision = math::blas::BLAS_Precision::d;
×
NEW
189
    }
×
190

191
    // Determine indvars
192
    symbolic::Symbol indvar_outer_1 = SymEngine::null, indvar_outer_2 = SymEngine::null, indvar_inner = SymEngine::null;
2✔
193
    symbolic::Expression m = SymEngine::null, n = SymEngine::null, k = SymEngine::null;
2✔
194
    std::vector<size_t> permutation = {0, 1, 2};
2✔
195
    do {
2✔
196
        if (symbolic::eq(this->einsum_node_.out_index(0), this->einsum_node_.indvar(permutation[0])) &&
2✔
197
            symbolic::eq(this->einsum_node_.out_index(1), this->einsum_node_.indvar(permutation[1]))) {
2✔
198
            indvar_outer_1 = this->einsum_node_.indvar(permutation[0]);
2✔
199
            indvar_outer_2 = this->einsum_node_.indvar(permutation[1]);
2✔
200
            indvar_inner = this->einsum_node_.indvar(permutation[2]);
2✔
201
            m = this->einsum_node_.bound(permutation[0]);
2✔
202
            n = this->einsum_node_.bound(permutation[1]);
2✔
203
            k = this->einsum_node_.bound(permutation[2]);
2✔
204
            break;
2✔
205
        }
2✔
206
    } while (std::next_permutation(permutation.begin(), permutation.end()));
2✔
207
    assert(
2✔
208
        !indvar_outer_1.is_null() && !indvar_outer_2.is_null() && !indvar_inner.is_null() && !m.is_null() &&
2✔
209
        !n.is_null() && !k.is_null()
2✔
210
    );
2✔
211

212
    // Determine inputs
213
    long long alpha = -1, A = -1, B = -1, C = -1;
2✔
214
    bool has_alpha = false;
2✔
215
    if (this->einsum_node_.inputs().size() == 3) {
2✔
216
        C = 2;
1✔
217
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
3✔
218
            if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
219
                break;
×
220
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
221
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
222
                A = i;
1✔
223
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
224
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
225
                B = i;
1✔
226
            }
1✔
227
        }
2✔
228
    } else if (this->einsum_node_.inputs().size() == 4) {
1✔
229
        C = 3;
1✔
230
        has_alpha = true;
1✔
231
        for (size_t i = 0; i < this->einsum_node_.in_indices().size() - 1; i++) {
4✔
232
            if (this->einsum_node_.in_indices(i).size() == 0) {
3✔
233
                alpha = i;
1✔
234
            } else if (this->einsum_node_.in_indices(i).size() != 2) {
2✔
NEW
235
                break;
×
236
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_1) ||
2✔
237
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_1)) {
2✔
238
                A = i;
1✔
239
            } else if (symbolic::eq(this->einsum_node_.in_index(i, 0), indvar_outer_2) ||
1✔
240
                       symbolic::eq(this->einsum_node_.in_index(i, 1), indvar_outer_2)) {
1✔
241
                B = i;
1✔
242
            }
1✔
243
        }
3✔
244
    }
1✔
245

246
    // Determine transpose and leading dimensions
247
    math::blas::BLAS_Transpose transA, transB;
2✔
248
    symbolic::Expression ldA, ldB, ldC;
2✔
249
    if (symbolic::eq(this->einsum_node_.in_index(A, 0), indvar_outer_1)) {
2✔
250
        transA = math::blas::BLAS_Transpose::No;
2✔
251
        ldA = k;
2✔
252
    } else {
2✔
NEW
253
        transA = math::blas::BLAS_Transpose::Trans;
×
NEW
254
        ldA = m;
×
NEW
255
    }
×
256
    if (symbolic::eq(this->einsum_node_.in_index(B, 1), indvar_outer_2)) {
2✔
257
        transB = math::blas::BLAS_Transpose::No;
2✔
258
        ldB = n;
2✔
259
    } else {
2✔
NEW
260
        transB = math::blas::BLAS_Transpose::Trans;
×
NEW
261
        ldB = k;
×
NEW
262
    }
×
263
    ldC = n;
2✔
264

265
    // Add the BLAS node for gemm
266
    auto& libnode = builder.add_library_node<math::blas::GEMMNode>(
2✔
267
        *block,
2✔
268
        this->einsum_node_.debug_info(),
2✔
269
        this->get_impl_type(data_type).value(),
2✔
270
        precision,
2✔
271
        math::blas::BLAS_Layout::RowMajor,
2✔
272
        transA,
2✔
273
        transB,
2✔
274
        m,
2✔
275
        n,
2✔
276
        k,
2✔
277
        ldA,
2✔
278
        ldB,
2✔
279
        ldC
2✔
280
    );
2✔
281

282
    // Copy the memlets
283
    for (auto& iedge : dfg.in_edges(this->einsum_node_)) {
7✔
284
        if (has_alpha && iedge.dst_conn() == this->einsum_node_.input(alpha)) {
7✔
285
            builder.add_memlet(
1✔
286
                *block,
1✔
287
                iedge.src(),
1✔
288
                iedge.src_conn(),
1✔
289
                libnode,
1✔
290
                "__alpha",
1✔
291
                iedge.subset(),
1✔
292
                iedge.base_type(),
1✔
293
                iedge.debug_info()
1✔
294
            );
1✔
295
        } else if (iedge.dst_conn() == this->einsum_node_.input(A)) {
6✔
296
            builder.add_memlet(
2✔
297
                *block,
2✔
298
                iedge.src(),
2✔
299
                iedge.src_conn(),
2✔
300
                libnode,
2✔
301
                "__A",
2✔
302
                iedge.subset(),
2✔
303
                iedge.base_type(),
2✔
304
                iedge.debug_info()
2✔
305
            );
2✔
306
        } else if (iedge.dst_conn() == this->einsum_node_.input(B)) {
4✔
307
            builder.add_memlet(
2✔
308
                *block,
2✔
309
                iedge.src(),
2✔
310
                iedge.src_conn(),
2✔
311
                libnode,
2✔
312
                "__B",
2✔
313
                iedge.subset(),
2✔
314
                iedge.base_type(),
2✔
315
                iedge.debug_info()
2✔
316
            );
2✔
317
        } else if (iedge.dst_conn() == this->einsum_node_.input(C)) {
2✔
318
            builder.add_memlet(
2✔
319
                *block,
2✔
320
                iedge.src(),
2✔
321
                iedge.src_conn(),
2✔
322
                libnode,
2✔
323
                "__C",
2✔
324
                iedge.subset(),
2✔
325
                iedge.base_type(),
2✔
326
                iedge.debug_info()
2✔
327
            );
2✔
328
        }
2✔
329
    }
7✔
330
    for (auto& oedge : dfg.out_edges(this->einsum_node_)) {
2✔
331
        if (oedge.src_conn() == this->einsum_node_.output(0)) {
2✔
332
            builder.add_memlet(
2✔
333
                *block,
2✔
334
                libnode,
2✔
335
                "__C",
2✔
336
                oedge.dst(),
2✔
337
                oedge.dst_conn(),
2✔
338
                oedge.subset(),
2✔
339
                oedge.base_type(),
2✔
340
                oedge.debug_info()
2✔
341
            );
2✔
342
        }
2✔
343
    }
2✔
344

345
    // Remove the old memlets
346
    while (dfg.in_edges(this->einsum_node_).begin() != dfg.in_edges(this->einsum_node_).end()) {
9✔
347
        auto& iedge = *dfg.in_edges(this->einsum_node_).begin();
7✔
348
        builder.remove_memlet(*block, iedge);
7✔
349
    }
7✔
350
    while (dfg.out_edges(this->einsum_node_).begin() != dfg.out_edges(this->einsum_node_).end()) {
4✔
351
        auto& oedge = *dfg.out_edges(this->einsum_node_).begin();
2✔
352
        builder.remove_memlet(*block, oedge);
2✔
353
    }
2✔
354

355
    // Add constant scalars alpha and beta (if needed)
356
    types::Scalar data_type_scalar(data_type);
2✔
357
    if (!has_alpha) {
2✔
358
        auto& alpha_access_node =
1✔
359
            builder.add_constant(*block, "1.0", data_type_scalar, this->einsum_node_.debug_info());
1✔
360
        builder.add_memlet(
1✔
361
            *block, alpha_access_node, "void", libnode, "__alpha", {}, data_type_scalar, this->einsum_node_.debug_info()
1✔
362
        );
1✔
363
    }
1✔
364
    auto& beta_access_node = builder.add_constant(*block, "1.0", data_type_scalar, this->einsum_node_.debug_info());
2✔
365
    builder.add_memlet(
2✔
366
        *block, beta_access_node, "void", libnode, "__beta", {}, data_type_scalar, this->einsum_node_.debug_info()
2✔
367
    );
2✔
368

369
    // Remove the einsum node
370
    builder.remove_node(*block, this->einsum_node_);
2✔
371

372
    analysis_manager.invalidate_all();
2✔
373
}
2✔
374

NEW
375
// Writes the description of this transformation into `j`: the transformation
// name, the element id of the einsum node and the target tune.
void Einsum2Gemm::to_json(nlohmann::json& j) const {
    const auto& node = this->einsum_node_;

    j["transformation_type"] = name();
    j["einsum_node_element_id"] = node.element_id();
    j["target_tune"] = target_tune_;
}
×
380

NEW
381
// Reconstructs the transformation from a JSON description.
//
// Reads "einsum_node_element_id" (required, unsigned number) and resolves it
// through `builder`; reads "target_tune" if present and falls back to "none"
// otherwise.
//
// Throws InvalidTransformationDescriptionException if no element with the given
// id exists or if the element is not an EinsumNode.
Einsum2Gemm Einsum2Gemm::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    // Only keys that are actually read below are required. The former assertion on
    // "impl_type" was dropped: to_json never writes that key and nothing here reads
    // it, so descriptions produced by to_json failed the assertion.
    assert(j.contains("einsum_node_element_id"));
    assert(j["einsum_node_element_id"].is_number_unsigned());

    // Resolve the element id to an einsum node
    size_t einsum_node_id = j["einsum_node_element_id"].get<size_t>();
    auto* einsum_node_element = builder.find_element_by_id(einsum_node_id);
    if (!einsum_node_element) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " not found"
        );
    }
    auto* einsum_node = dynamic_cast<einsum::EinsumNode*>(einsum_node_element);
    if (!einsum_node) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(einsum_node_id) + " is not an EinsumNode"
        );
    }

    // The target tune is optional; "none" matches no tune known to get_impl_type
    std::string target_tune;
    if (j.contains("target_tune")) {
        target_tune = j.at("target_tune").get<std::string>();
    } else {
        target_tune = "none";
    }

    return Einsum2Gemm(*einsum_node, target_tune);
}
×
409

410
} // namespace transformations
411
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc