• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20388902294

18 Dec 2025 07:41PM UTC coverage: 39.328% (-0.1%) from 39.454%
20388902294

push

github

web-flow
Merge pull request #400 from daisytuner/backward-symbol-propagation

adds backward symbol propagation

13427 of 44322 branches covered (30.29%)

Branch coverage included in aggregate %.

113 of 250 new or added lines in 4 files covered. (45.2%)

13 existing lines in 3 files now uncovered.

11565 of 19225 relevant lines covered (60.16%)

83.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

32.19
/src/passes/symbolic/symbol_propagation.cpp
1
#include "sdfg/passes/symbolic/symbol_propagation.h"
2

3
#include "sdfg/analysis/data_dependency_analysis.h"
4
#include "sdfg/analysis/dominance_analysis.h"
5
#include "sdfg/analysis/users.h"
6

7
namespace sdfg {
8
namespace passes {
9

10
/**
 * Inverts a simple in-place update of a symbol.
 *
 * `lhs = rhs` is interpreted as an update of `lhs`. When `rhs` is exactly
 * `lhs + c` or `lhs * c` for an integer constant `c`, the returned expression
 * expresses the value `lhs` held before the update in terms of its value after
 * it (`lhs - c` or `lhs / c` respectively). An expression over the old value
 * can then be rewritten by substituting `lhs` with the result.
 *
 * @param lhs Updated symbol.
 * @param rhs Expression assigned to `lhs`.
 * @return The inverse expression, or SymEngine::null if `rhs` does not use
 *         `lhs` or is not one of the two supported binary shapes.
 */
symbolic::Expression inverse(const symbolic::Symbol lhs, const symbolic::Expression rhs) {
    if (!symbolic::uses(rhs, lhs)) {
        return SymEngine::null;
    }

    const bool is_add = SymEngine::is_a<SymEngine::Add>(*rhs);
    const bool is_mul = SymEngine::is_a<SymEngine::Mul>(*rhs);
    if (!is_add && !is_mul) {
        return SymEngine::null;
    }

    // Only the binary form `lhs (+|*) c` is invertible here. SymEngine flattens
    // sums/products, so e.g. `1 + lhs + lhs**2` is ONE Add with three arguments.
    // Inspecting only the first two of them can produce the wrong inverse `lhs - 1`.
    const auto args = rhs->get_args();
    if (args.size() != 2) {
        return SymEngine::null;
    }

    // Argument order is not fixed: normalize so that `operand` is `lhs`.
    auto operand = args[0];
    auto constant = args[1];
    if (!symbolic::eq(operand, lhs)) {
        std::swap(operand, constant);
    }
    if (!symbolic::eq(operand, lhs)) {
        return SymEngine::null;
    }
    if (!SymEngine::is_a<SymEngine::Integer>(*constant)) {
        return SymEngine::null;
    }

    if (is_add) {
        return symbolic::sub(lhs, constant);
    }
    return symbolic::div(lhs, constant);
}
1✔
47

48
// Constructs the pass. No pass-specific state is initialized here; only the
// Pass base is default-constructed.
SymbolPropagation::SymbolPropagation() : Pass() {}
12✔
52

53
// Returns the name under which this pass is reported.
std::string SymbolPropagation::name() {
    const char* const pass_name = "SymbolPropagation";
    return pass_name;
}
×
54

55
bool SymbolPropagation::run_pass(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
12✔
56
    bool applied = false;
12✔
57

58
    auto& sdfg = builder.subject();
12✔
59
    auto& users = analysis_manager.get<analysis::Users>();
12✔
60
    auto& dominance_analysis = analysis_manager.get<analysis::DominanceAnalysis>();
12✔
61
    auto& data_dependency_analysis = analysis_manager.get<analysis::DataDependencyAnalysis>();
12✔
62
    std::unordered_set<data_flow::AccessNode*> replaced_nodes;
12✔
63
    std::unordered_set<std::string> skip;
12✔
64
    for (auto& name : sdfg.containers()) {
39!
65
        // Criterion: Only transients
66
        if (!sdfg.is_transient(name)) {
27!
67
            continue;
6✔
68
        }
69
        if (skip.find(name) != skip.end()) {
21!
NEW
70
            continue;
×
71
        }
72

73
        // Criterion: Only integers
74
        auto& type = builder.subject().type(name);
21!
75
        auto scalar = dynamic_cast<const types::Scalar*>(&type);
21!
76
        if (!scalar || !types::is_integer(scalar->primitive_type())) {
21!
77
            continue;
2✔
78
        }
79

80
        // The symbol will become the LHS (to be replaced)
81
        auto lhs = symbolic::symbol(name);
19!
82

83
        // Collect all reads of the symbol w.r.t to their writes
84
        auto raw_groups = data_dependency_analysis.defined_by(name);
19!
85
        for (auto& entry : raw_groups) {
39✔
86
            if (entry.first->use() != analysis::Use::READ) {
20!
UNCOV
87
                continue;
×
88
            }
89
            auto read = entry.first;
20✔
90

91
            // Reverse propagation
92
            if (entry.second.size() == 2) {
20✔
93
                // if (...) { a = 1} else { a = 2 } b = a
94
                // -> if (...) { a = 1, b = 1} else { a = 2, b = 2 }
95
                auto write1 = *entry.second.begin();
7✔
96
                auto write2 = *(++entry.second.begin());
7✔
97
                if (write1->container() != write2->container()) {
7!
UNCOV
98
                    continue;
×
99
                }
100
                auto rhs = symbolic::symbol(write1->container());
7!
101
                if (data_dependency_analysis.is_undefined_user(*write1) ||
14!
102
                    data_dependency_analysis.is_undefined_user(*write2)) {
7!
NEW
103
                    continue;
×
104
                }
105
                if (users.num_reads(write1->container()) != 1) {
7!
106
                    continue;
7✔
107
                }
NEW
108
                if (!dominance_analysis.post_dominates(*read, *write1) ||
×
NEW
109
                    !dominance_analysis.post_dominates(*read, *write2)) {
×
NEW
110
                    continue;
×
111
                }
NEW
112
                auto transition1 = dynamic_cast<structured_control_flow::Transition*>(write1->element());
×
NEW
113
                auto transition2 = dynamic_cast<structured_control_flow::Transition*>(write2->element());
×
NEW
114
                if (!transition1 || !transition2) {
×
NEW
115
                    continue;
×
116
                }
NEW
117
                auto transition_lhs = dynamic_cast<structured_control_flow::Transition*>(read->element());
×
NEW
118
                if (!transition_lhs) {
×
NEW
119
                    continue;
×
120
                }
NEW
121
                symbolic::Symbol lhs = SymEngine::null;
×
NEW
122
                for (auto& assign_entry : transition_lhs->assignments()) {
×
NEW
123
                    if (symbolic::eq(assign_entry.second, rhs)) {
×
NEW
124
                        lhs = assign_entry.first;
×
UNCOV
125
                        break;
×
126
                    }
127
                }
NEW
128
                if (lhs.is_null()) {
×
NEW
129
                    continue;
×
130
                }
NEW
131
                if (transition1->assignments().find(lhs) != transition1->assignments().end() ||
×
NEW
132
                    transition2->assignments().find(lhs) != transition2->assignments().end()) {
×
NEW
133
                    continue;
×
134
                }
135

NEW
136
                auto rhs1 = transition1->assignments().at(rhs);
×
NEW
137
                if (symbolic::uses(rhs1, lhs)) {
×
NEW
138
                    if (!symbolic::eq(rhs1, lhs)) {
×
NEW
139
                        continue;
×
140
                    }
UNCOV
141
                }
×
NEW
142
                auto rhs2 = transition2->assignments().at(rhs);
×
NEW
143
                if (symbolic::uses(rhs2, lhs)) {
×
NEW
144
                    if (!symbolic::eq(rhs2, lhs)) {
×
UNCOV
145
                        continue;
×
146
                    }
NEW
147
                }
×
NEW
148
                transition1->assignments().insert({lhs, rhs1});
×
NEW
149
                transition2->assignments().insert({lhs, rhs2});
×
NEW
150
                transition_lhs->assignments().erase(lhs);
×
NEW
151
                skip.insert(lhs->get_name());
×
NEW
152
                skip.insert(rhs->get_name());
×
NEW
153
                applied = true;
×
NEW
154
                break;
×
155
            }
7!
156
            // Forward propagation
157
            else if (entry.second.size() == 1) {
13!
158
                // Criterion: Write must be a transition
159
                auto write = *entry.second.begin();
13✔
160
                if (data_dependency_analysis.is_undefined_user(*write)) {
13!
NEW
161
                    continue;
×
162
                }
163
                auto transition = dynamic_cast<structured_control_flow::Transition*>(write->element());
13!
164
                if (!transition) {
13!
NEW
165
                    continue;
×
166
                }
167

168
                // We now define the rhs (to be propagated expression)
169
                if (transition->assignments().count(lhs) == 0) {
13!
170
                    // Reverse propagation already applied
NEW
171
                    continue;
×
172
                }
173
                auto rhs = transition->assignments().at(lhs);
13!
174

175
                // Criterion: RHS is not trivial and not recursive
176
                if (symbolic::eq(lhs, rhs) || symbolic::uses(rhs, lhs)) {
13!
NEW
177
                    continue;
×
178
                }
179

180
                // Criterion: Write dominates read to not cause data races
181
                if (!dominance_analysis.dominates(*write, *read)) {
13!
NEW
182
                    continue;
×
183
                }
184

185
                // Collect all symbols used in the RHS
186
                std::unordered_set<std::string> rhs_symbols;
13✔
187
                for (auto& sym : symbolic::atoms(rhs)) {
19!
188
                    if (symbolic::eq(sym, symbolic::__nullptr__())) {
6!
UNCOV
189
                        continue;
×
190
                    }
191
                    rhs_symbols.insert(sym->get_name());
6!
192
                }
193

194
                auto rhs_modified = rhs;
13!
195

196
                // Find dangerous users between write and read
197
                auto is_dangerous = [&](analysis::User* user) {
17✔
198
                    if (user == write || user == read) {
4!
NEW
199
                        return false;
×
200
                    }
201

202
                    // Criterion: RHS must dominate modification
203
                    if (!dominance_analysis.dominates(*write, *user)) {
4✔
204
                        return false;
1✔
205
                    }
206

207
                    // Criterion: Modification must dominate read
208
                    if (dominance_analysis.dominates(*read, *user)) {
3✔
209
                        return false;
2✔
210
                    }
211

212
                    return true;
1✔
213
                };
4✔
214
                std::unordered_set<std::string> dangerous_users;
13✔
215
                for (const auto& sym : rhs_symbols) {
19✔
216
                    for (auto* user : users.writes(sym)) {
9!
217
                        if (is_dangerous(user)) {
4!
218
                            dangerous_users.insert(sym);
1!
219
                            break;
1✔
220
                        }
221
                    }
222
                    for (auto* user : users.moves(sym)) {
6!
NEW
223
                        if (is_dangerous(user)) {
×
NEW
224
                            dangerous_users.insert(sym);
×
NEW
225
                            break;
×
226
                        }
227
                    }
228
                }
229
                if (!dangerous_users.empty()) {
13✔
230
                    // RHS' symbols may be written between write and read
231
                    // We attempt to create the new RHS
232
                    bool success = true;
1✔
233
                    std::unordered_set<std::string> modified_symbols;
1✔
234
                    auto middle_users = users.all_uses_between(*write, *read);
1!
235
                    for (auto& user : middle_users) {
1!
236
                        if (user->use() != analysis::Use::WRITE && user->use() != analysis::Use::MOVE) {
1!
NEW
237
                            continue;
×
238
                        }
239
                        if (rhs_symbols.find(user->container()) == rhs_symbols.end()) {
1!
NEW
240
                            continue;
×
241
                        }
242

243
                        // Criterion: Symbol is only modified once
244
                        if (modified_symbols.find(user->container()) != modified_symbols.end()) {
1!
245
                            success = false;
×
246
                            break;
×
247
                        }
248

249
                        // Criterion: RHS must dominate modification
250
                        if (!dominance_analysis.dominates(*write, *user)) {
1!
NEW
251
                            success = false;
×
NEW
252
                            break;
×
253
                        }
254

255
                        // Criterion: Modification must dominate read
256
                        if (!dominance_analysis.dominates(*user, *read)) {
1!
NEW
257
                            success = false;
×
NEW
258
                            break;
×
259
                        }
260

261
                        // Criterion: Only transitions
262
                        if (!dynamic_cast<structured_control_flow::Transition*>(user->element())) {
1!
NEW
263
                            success = false;
×
NEW
264
                            break;
×
265
                        }
266
                        auto sym_transition = dynamic_cast<structured_control_flow::Transition*>(user->element());
1!
267
                        auto sym_lhs = symbolic::symbol(user->container());
1!
268
                        auto sym_rhs = sym_transition->assignments().at(sym_lhs);
1!
269

270
                        // Limited to constants
271
                        for (auto& atom : symbolic::atoms(sym_rhs)) {
1!
NEW
272
                            if (!symbolic::eq(atom, sym_lhs)) {
×
NEW
273
                                success = false;
×
NEW
274
                                break;
×
275
                            }
276
                        }
277
                        if (!success) {
1!
NEW
278
                            break;
×
279
                        }
280

281
                        auto inv = inverse(sym_lhs, sym_rhs);
1!
282
                        if (inv == SymEngine::null) {
1!
283
                            success = false;
1✔
284
                            break;
1✔
285
                        }
286

NEW
287
                        rhs_modified = symbolic::subs(rhs_modified, sym_lhs, inv);
×
NEW
288
                        modified_symbols.insert(user->container());
×
289
                    }
1!
290
                    if (!success) {
1!
291
                        continue;
1✔
292
                    }
293
                }
1!
294
                rhs_modified = symbolic::simplify(rhs_modified);
12!
295

296
                if (auto transition_stmt = dynamic_cast<structured_control_flow::Transition*>(read->element())) {
12!
297
                    auto& assignments = transition_stmt->assignments();
7!
298
                    for (auto& entry : assignments) {
14✔
299
                        if (symbolic::uses(entry.second, lhs)) {
7!
300
                            entry.second = symbolic::subs(entry.second, lhs, rhs_modified);
7!
301
                            applied = true;
7✔
302
                        }
7✔
303
                    }
304
                } else if (auto if_else_stmt = dynamic_cast<structured_control_flow::IfElse*>(read->element())) {
12!
305
                    // Criterion: RHS does not use nvptx symbols
NEW
306
                    bool nvptx = false;
×
NEW
307
                    for (auto& atom : symbolic::atoms(rhs_modified)) {
×
NEW
308
                        if (symbolic::is_nv(atom)) {
×
NEW
309
                            nvptx = true;
×
NEW
310
                            break;
×
311
                        }
312
                    }
NEW
313
                    if (nvptx) {
×
UNCOV
314
                        continue;
×
315
                    }
316

NEW
317
                    for (size_t i = 0; i < if_else_stmt->size(); i++) {
×
NEW
318
                        auto child = if_else_stmt->at(i);
×
NEW
319
                        if (symbolic::uses(child.second, lhs)) {
×
NEW
320
                            builder.update_if_else_condition(
×
NEW
321
                                *if_else_stmt, i, symbolic::subs(child.second, lhs, rhs_modified)
×
322
                            );
NEW
323
                            applied = true;
×
NEW
324
                        }
×
325
                    }
×
326
                } else if (auto memlet = dynamic_cast<data_flow::Memlet*>(read->element())) {
5!
327
                    bool used = false;
2✔
328
                    auto subset = memlet->subset();
2!
329
                    for (auto& dim : subset) {
4✔
330
                        if (symbolic::uses(dim, lhs)) {
2!
331
                            dim = symbolic::subs(dim, lhs, rhs_modified);
2!
332
                            used = true;
2✔
333
                        }
2✔
334
                    }
335
                    if (used) {
2!
336
                        memlet->set_subset(subset);
2!
337
                        applied = true;
2✔
338
                    }
2✔
339
                } else if (auto access_node = dynamic_cast<data_flow::AccessNode*>(read->element())) {
5!
NEW
340
                    if (SymEngine::is_a<SymEngine::Symbol>(*rhs_modified)) {
×
NEW
341
                        auto new_symbol = SymEngine::rcp_static_cast<const SymEngine::Symbol>(rhs_modified);
×
NEW
342
                        if (symbolic::is_nullptr(new_symbol) ||
×
NEW
343
                            sdfg.type(new_symbol->get_name()).type_id() == types::TypeID::Pointer) {
×
NEW
344
                            continue;
×
345
                        }
NEW
346
                        access_node->data(new_symbol->get_name());
×
NEW
347
                        applied = true;
×
NEW
348
                        replaced_nodes.insert(access_node);
×
NEW
349
                    } else if (SymEngine::is_a<SymEngine::Integer>(*rhs_modified)) {
×
NEW
350
                        auto new_int = SymEngine::rcp_static_cast<const SymEngine::Integer>(rhs_modified);
×
NEW
351
                        auto& graph = access_node->get_parent();
×
NEW
352
                        auto block = static_cast<structured_control_flow::Block*>(graph.get_parent());
×
353

354
                        // Replace with const node
NEW
355
                        auto& const_node =
×
NEW
356
                            builder
×
NEW
357
                                .add_constant(*block, std::to_string(new_int->as_int()), type, access_node->debug_info());
×
358

NEW
359
                        std::unordered_set<data_flow::Memlet*> replace_edges;
×
NEW
360
                        for (auto& oedge : graph.out_edges(*access_node)) {
×
NEW
361
                            builder.add_memlet(
×
NEW
362
                                *block,
×
NEW
363
                                const_node,
×
NEW
364
                                oedge.src_conn(),
×
NEW
365
                                oedge.dst(),
×
NEW
366
                                oedge.dst_conn(),
×
NEW
367
                                oedge.subset(),
×
NEW
368
                                oedge.base_type(),
×
NEW
369
                                oedge.debug_info()
×
370
                            );
NEW
371
                            replace_edges.insert(&oedge);
×
372
                        }
NEW
373
                        for (auto& iedge : graph.in_edges(*access_node)) {
×
NEW
374
                            builder.add_memlet(
×
NEW
375
                                *block,
×
NEW
376
                                iedge.src(),
×
NEW
377
                                iedge.src_conn(),
×
NEW
378
                                const_node,
×
NEW
379
                                iedge.dst_conn(),
×
NEW
380
                                iedge.subset(),
×
NEW
381
                                iedge.base_type(),
×
NEW
382
                                iedge.debug_info()
×
383
                            );
384
                        }
385

NEW
386
                        for (auto& edge : replace_edges) {
×
NEW
387
                            builder.remove_memlet(*block, *edge);
×
388
                        }
NEW
389
                        builder.remove_node(*block, *access_node);
×
NEW
390
                        applied = true;
×
391
                    }
×
392
                } else if (auto library_node = dynamic_cast<data_flow::LibraryNode*>(read->element())) {
3!
NEW
393
                    for (auto& symbol : library_node->symbols()) {
×
NEW
394
                        if (symbolic::eq(symbol, lhs)) {
×
NEW
395
                            library_node->replace(symbol, rhs_modified);
×
NEW
396
                            applied = true;
×
NEW
397
                        }
×
398
                    }
399
                } else if (auto for_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(read->element())) {
3!
400
                    auto for_user = dynamic_cast<analysis::ForUser*>(read);
3!
401
                    if (for_user->is_init() && symbolic::uses(for_loop->init(), lhs)) {
4!
402
                        auto new_init = symbolic::subs(for_loop->init(), lhs, rhs_modified);
1!
403
                        new_init = symbolic::simplify(new_init);
1!
404
                        builder.update_loop(
2!
405
                            *for_loop, for_loop->indvar(), for_loop->condition(), new_init, for_loop->update()
1!
406
                        );
407
                        applied = true;
1✔
408
                    } else if (for_user->is_condition() && symbolic::uses(for_loop->condition(), lhs)) {
4!
409
                        auto new_condition = symbolic::subs(for_loop->condition(), lhs, rhs_modified);
1!
410
                        new_condition =
1!
411
                            SymEngine::rcp_dynamic_cast<const SymEngine::Boolean>(symbolic::simplify(new_condition));
1!
412
                        builder.update_loop(
2!
413
                            *for_loop, for_loop->indvar(), new_condition, for_loop->init(), for_loop->update()
1!
414
                        );
415
                        applied = true;
1✔
416
                    } else if (for_user->is_update() && symbolic::uses(for_loop->update(), lhs)) {
3!
417
                        auto new_update = symbolic::subs(for_loop->update(), lhs, rhs_modified);
1!
418
                        new_update = symbolic::simplify(new_update);
1!
419
                        builder.update_loop(
2!
420
                            *for_loop, for_loop->indvar(), for_loop->condition(), for_loop->init(), new_update
1!
421
                        );
422
                        applied = true;
1✔
423
                    }
1✔
424
                }
3✔
425
            }
13✔
426
        }
427
    }
19✔
428

429
    // Post-processing: Merge access nodes and remove dangling nodes
430
    // Avoid removing elements while iterating above
431
    for (auto* node : replaced_nodes) {
12!
432
        builder.merge_siblings(*node);
×
433
    }
434
    for (auto* node : replaced_nodes) {
12!
435
        auto& graph = node->get_parent();
×
436
        auto* block = static_cast<structured_control_flow::Block*>(graph.get_parent());
×
437
        for (auto& dnode : graph.data_nodes()) {
×
438
            if (graph.in_degree(*dnode) == 0 && graph.out_degree(*dnode) == 0) {
×
439
                builder.remove_node(*block, *dnode);
×
440
            }
×
441
        }
442
    }
443

444
    return applied;
12✔
445
};
12✔
446

447
} // namespace passes
448
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc