• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27623444566

16 Jun 2026 02:06PM UTC coverage: 61.524% (-0.02%) from 61.54%
27623444566

push

github

web-flow
Merge pull request #767 from daisytuner/loop-analysis-stacks

Extending LoopAnalysis to provide data for partially perfect loop nests

333 of 398 new or added lines in 4 files covered. (83.67%)

9 existing lines in 2 files now uncovered.

36494 of 59317 relevant lines covered (61.52%)

1107.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.96
/opt/src/transformations/map_fusion.cpp
1
#include "sdfg/transformations/map_fusion.h"
2

3
#include <isl/ctx.h>
4
#include <isl/map.h>
5
#include <isl/options.h>
6
#include <isl/set.h>
7
#include <isl/space.h>
8
#include <symengine/solve.h>
9
#include "sdfg/analysis/arguments_analysis.h"
10
#include "sdfg/analysis/loop_analysis.h"
11

12
#include "sdfg/analysis/assumptions_analysis.h"
13
#include "sdfg/control_flow/interstate_edge.h"
14
#include "sdfg/data_flow/data_flow_graph.h"
15
#include "sdfg/structured_control_flow/block.h"
16
#include "sdfg/structured_control_flow/for.h"
17
#include "sdfg/symbolic/delinearization.h"
18
#include "sdfg/symbolic/utils.h"
19
#include "sdfg/visitor/structured_sdfg_visitor.h"
20

21
namespace sdfg {
22
namespace transformations {
23

24
class FusionConsumerSubsetVisitor : public visitor::ActualStructuredSDFGVisitor {
25
    friend MapFusion;
26

27
    std::unordered_map<std::string, const data_flow::Subset*>& target_containers_;
28
    std::unordered_map<std::string, std::vector<data_flow::Subset>> unique_subsets_per_container_;
29

30
protected:
31
    bool abort() { return true; }
1✔
32

33
public:
34
    FusionConsumerSubsetVisitor(std::unordered_map<std::string, const data_flow::Subset*>& target_containers)
35
        : target_containers_(target_containers) {}
33✔
36

37
    bool visit(sdfg::structured_control_flow::Block& block) override {
34✔
38
        auto& dataflow = block.dataflow();
34✔
39
        for (auto& node : dataflow.nodes()) {
119✔
40
            auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
119✔
41
            if (access == nullptr) {
119✔
42
                continue;
36✔
43
            }
36✔
44
            auto& container = access->data();
83✔
45

46
            auto target_it = target_containers_.find(container);
83✔
47
            if (target_it == target_containers_.end()) {
83✔
48
                continue;
44✔
49
            }
44✔
50
            auto& producer_subset = *target_it->second;
39✔
51
            auto& unique_subsets = unique_subsets_per_container_[container]; // Ensures entry exists
39✔
52

53
            // Skip write-only access nodes (consumer also writes the fusion container)
54
            if (dataflow.in_degree(*access) > 0 && dataflow.out_degree(*access) == 0) {
39✔
55
                continue;
2✔
56
            }
2✔
57
            if (dataflow.in_degree(*access) != 0 || dataflow.out_degree(*access) == 0) {
37✔
NEW
58
                return abort();
×
NEW
59
            }
×
60

61
            // Check all read memlets from this access
62
            for (auto& memlet : dataflow.out_edges(*access)) {
39✔
63
                if (memlet.type() != data_flow::MemletType::Computational) {
39✔
NEW
64
                    return abort();
×
NEW
65
                }
×
66

67
                auto& consumer_subset = memlet.subset();
39✔
68
                if (consumer_subset.size() != producer_subset.size()) {
39✔
69
                    return abort();
1✔
70
                }
1✔
71

72
                // Check if this subset is already in unique_subsets
73
                bool found = false;
38✔
74
                for (const auto& existing : unique_subsets) {
38✔
75
                    if (existing.size() != consumer_subset.size()) continue;
7✔
76
                    bool match = true;
7✔
77
                    for (size_t d = 0; d < existing.size(); ++d) {
9✔
78
                        if (!symbolic::eq(existing[d], consumer_subset[d])) {
7✔
79
                            match = false;
5✔
80
                            break;
5✔
81
                        }
5✔
82
                    }
7✔
83
                    if (match) {
7✔
84
                        found = true;
2✔
85
                        break;
2✔
86
                    }
2✔
87
                }
7✔
88
                if (!found) {
38✔
89
                    unique_subsets.push_back(consumer_subset);
36✔
90
                }
36✔
91
            }
38✔
92
        }
37✔
93
        return false;
33✔
94
    }
34✔
95

96
    bool visit(sdfg::structured_control_flow::Sequence& node) override {
33✔
97
        for (int i = 0; i < node.size(); ++i) {
66✔
98
            if (dispatch(node.at(i).first)) {
34✔
99
                return true;
1✔
100
            }
1✔
101
        }
34✔
102

103
        return false;
32✔
104
    }
33✔
105

NEW
106
    bool visit(IfElse& node) override {
×
NEW
107
        for (int i = 0; i < node.size(); ++i) {
×
NEW
108
            if (visit(node.at(i).first)) {
×
NEW
109
                return true;
×
NEW
110
            }
×
NEW
111
        }
×
112

NEW
113
        return false;
×
NEW
114
    }
×
115
};
116

117
class FusionConsumerUpdateVisitor : public visitor::ActualStructuredSDFGVisitor {
118
    friend MapFusion;
119

120
    builder::StructuredSDFGBuilder& builder_;
121
    const std::vector<MapFusion::FusionCandidate>& fusion_candidates_;
122
    const std::vector<std::string>& candidate_temps_;
123

124
public:
125
    FusionConsumerUpdateVisitor(
126
        builder::StructuredSDFGBuilder& builder,
127
        const std::vector<MapFusion::FusionCandidate>& fusion_candidates,
128
        const std::vector<std::string>& candidate_temps
129
    )
130
        : builder_(builder), fusion_candidates_(fusion_candidates), candidate_temps_(candidate_temps) {}
16✔
131

132
    bool dispatch_partial_sequence(Sequence& node, size_t first, size_t end) {
16✔
133
        for (int i = first; i < end; ++i) {
34✔
134
            if (dispatch(node.at(i).first)) {
18✔
NEW
135
                return true;
×
NEW
136
            }
×
137
        }
18✔
138

139
        return false;
16✔
140
    }
16✔
141

142
    bool visit(sdfg::structured_control_flow::Block& block) override {
18✔
143
        auto& dataflow = block.dataflow();
18✔
144

145
        // Snapshot access nodes before mutation: adding new access nodes below
146
        // would rehash dataflow.nodes_ and invalidate the range iterator.
147
        std::vector<data_flow::AccessNode*> access_nodes;
18✔
148
        for (auto& node : dataflow.nodes()) {
64✔
149
            auto* an = dynamic_cast<data_flow::AccessNode*>(&node);
64✔
150
            if (an != nullptr && dataflow.out_degree(*an) > 0) {
64✔
151
                access_nodes.push_back(an);
26✔
152
            }
26✔
153
        }
64✔
154

155
        for (auto* access : access_nodes) {
26✔
156
            std::string original_container = access->data();
26✔
157

158
            // Match each out-edge against a fusion candidate.
159
            struct Match {
26✔
160
                data_flow::Memlet* memlet;
26✔
161
                size_t cand_idx;
26✔
162
            };
26✔
163
            std::vector<Match> matches;
26✔
164
            for (auto& memlet : dataflow.out_edges(*access)) {
28✔
165
                if (memlet.type() != data_flow::MemletType::Computational) {
28✔
NEW
166
                    continue;
×
NEW
167
                }
×
168
                const auto& memlet_subset = memlet.subset();
28✔
169
                for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
38✔
170
                    auto& candidate = fusion_candidates_[cand_idx];
30✔
171
                    if (original_container != candidate.container) {
30✔
172
                        continue;
8✔
173
                    }
8✔
174
                    if (memlet_subset.size() != candidate.consumer_subset.size()) {
22✔
NEW
175
                        continue;
×
NEW
176
                    }
×
177
                    bool subset_matches = true;
22✔
178
                    for (size_t d = 0; d < memlet_subset.size(); ++d) {
45✔
179
                        if (!symbolic::eq(memlet_subset[d], candidate.consumer_subset[d])) {
25✔
180
                            subset_matches = false;
2✔
181
                            break;
2✔
182
                        }
2✔
183
                    }
25✔
184
                    if (subset_matches) {
22✔
185
                        matches.push_back({&memlet, cand_idx});
20✔
186
                        break;
20✔
187
                    }
20✔
188
                }
22✔
189
            }
28✔
190
            if (matches.empty()) {
26✔
191
                continue;
8✔
192
            }
8✔
193

194
            // Group matches by candidate index.
195
            std::unordered_set<size_t> distinct_cands;
18✔
196
            for (auto& m : matches) {
20✔
197
                distinct_cands.insert(m.cand_idx);
20✔
198
            }
20✔
199

200
            if (distinct_cands.size() == 1) {
18✔
201
                // Fast path: all matched out-edges resolve to the same candidate.
202
                // Mutate the shared access node in place — this preserves the
203
                // existing semantics for the single-read-per-container case.
204
                size_t cand_idx = *distinct_cands.begin();
17✔
205
                const auto& temp_name = candidate_temps_[cand_idx];
17✔
206
                auto& temp_type = builder_.subject().type(temp_name);
17✔
207

208
                access->data(temp_name);
17✔
209

210
                for (auto& m : matches) {
18✔
211
                    m.memlet->set_subset({});
18✔
212
                    m.memlet->set_base_type(temp_type);
18✔
213
                }
18✔
214

215
                for (auto& in_edge : dataflow.in_edges(*access)) {
17✔
NEW
216
                    in_edge.set_subset({});
×
NEW
217
                    in_edge.set_base_type(temp_type);
×
NEW
218
                }
×
219
            } else {
17✔
220
                // Stencil-like case: a single access node feeds reads at
221
                // multiple distinct subsets (e.g. T[j-1] and T[j+1] sharing
222
                // one AccessNode). Each must be rewired to its own
223
                // candidate-specific temp scalar — otherwise mutating
224
                // `access->data()` once per candidate makes all reads
225
                // collapse onto the last temp, e.g. T[j+1]-T[j] becomes
226
                // tmp-tmp == 0.
227
                //
228
                // Fix: for each distinct candidate, create one fresh
229
                // AccessNode for its temp scalar and redirect the matched
230
                // edges from the shared access node to the fresh nodes.
231
                struct PendingRedirect {
1✔
232
                    data_flow::DataFlowNode* dst;
1✔
233
                    std::string src_conn;
1✔
234
                    std::string dst_conn;
1✔
235
                    DebugInfo debug_info;
1✔
236
                    size_t cand_idx;
1✔
237
                    const data_flow::Memlet* memlet_to_remove;
1✔
238
                };
1✔
239
                std::vector<PendingRedirect> pending;
1✔
240
                pending.reserve(matches.size());
1✔
241
                for (auto& m : matches) {
2✔
242
                    pending.push_back(
2✔
243
                        {&m.memlet->dst(),
2✔
244
                         m.memlet->src_conn(),
2✔
245
                         m.memlet->dst_conn(),
2✔
246
                         m.memlet->debug_info(),
2✔
247
                         m.cand_idx,
2✔
248
                         m.memlet}
2✔
249
                    );
2✔
250
                }
2✔
251

252
                std::unordered_map<size_t, data_flow::AccessNode*> per_cand_node;
1✔
253
                for (auto& p : pending) {
2✔
254
                    auto it = per_cand_node.find(p.cand_idx);
2✔
255
                    if (it == per_cand_node.end()) {
2✔
256
                        auto& fresh = builder_.add_access(block, candidate_temps_[p.cand_idx]);
2✔
257
                        it = per_cand_node.emplace(p.cand_idx, &fresh).first;
2✔
258
                    }
2✔
259
                    auto& temp_type = builder_.subject().type(candidate_temps_[p.cand_idx]);
2✔
260
                    builder_.remove_memlet(block, *p.memlet_to_remove);
2✔
261
                    builder_.add_memlet(block, *it->second, p.src_conn, *p.dst, p.dst_conn, {}, temp_type, p.debug_info);
2✔
262
                }
2✔
263

264
                // If the original shared access node now has no edges at all
265
                // it is dangling and should be removed. Keep it if it still
266
                // has out-edges (unmatched reads of the original container)
267
                // or in-edges (writes to the original container).
268
                if (dataflow.out_degree(*access) == 0 && dataflow.in_degree(*access) == 0) {
1✔
269
                    builder_.remove_node(block, *access);
1✔
270
                }
1✔
271
            }
1✔
272
        }
18✔
273
        return false;
18✔
274
    }
18✔
275

NEW
276
    bool visit(sdfg::structured_control_flow::Sequence& node) override {
×
NEW
277
        for (int i = 0; i < node.size(); ++i) {
×
NEW
278
            if (dispatch(node.at(i).first)) {
×
NEW
279
                return true;
×
NEW
280
            }
×
NEW
281
        }
×
282

NEW
283
        return false;
×
NEW
284
    }
×
285

NEW
286
    bool visit(IfElse& node) override {
×
NEW
287
        for (int i = 0; i < node.size(); ++i) {
×
NEW
288
            if (visit(node.at(i).first)) {
×
NEW
289
                return true;
×
NEW
290
            }
×
NEW
291
        }
×
292

NEW
293
        return false;
×
NEW
294
    }
×
295
};
296

297
MapFusion::MapFusion(
298
    structured_control_flow::Map& first_map,
299
    structured_control_flow::StructuredLoop& second_loop,
300
    bool require_consecutive
301
)
302
    : first_map_(first_map), second_loop_(second_loop), require_consecutive_(require_consecutive) {}
43✔
303

304
/// Identifier of this transformation: "MapFusion".
std::string MapFusion::name() const {
    return std::string{"MapFusion"};
}
305

306
std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> MapFusion::solve_subsets(
307
    const data_flow::Subset& producer_subset,
308
    const data_flow::Subset& consumer_subset,
309
    const std::vector<structured_control_flow::StructuredLoop*>& producer_loops,
310
    const std::vector<structured_control_flow::StructuredLoop*>& consumer_loops,
311
    const symbolic::Assumptions& producer_assumptions,
312
    const symbolic::Assumptions& consumer_assumptions,
313
    bool invert_range_check
314
) {
36✔
315
    // Delinearize subsets to recover multi-dimensional structure from linearized accesses
316
    // e.g. T[i*N + j] with assumptions on bounds -> T[i, j]
317
    auto producer_sub = producer_subset;
36✔
318
    if (producer_sub.size() == 1) {
36✔
319
        auto producer_result = symbolic::delinearize(producer_sub.at(0), producer_assumptions);
27✔
320
        if (producer_result.success) {
27✔
321
            producer_sub = producer_result.indices;
24✔
322
        }
24✔
323
    }
27✔
324
    auto consumer_sub = consumer_subset;
36✔
325
    if (consumer_sub.size() == 1) {
36✔
326
        auto consumer_result = symbolic::delinearize(consumer_sub.at(0), consumer_assumptions);
27✔
327
        if (consumer_result.success) {
27✔
328
            consumer_sub = consumer_result.indices;
24✔
329
        }
24✔
330
    }
27✔
331

332
    // Subset dimensions must match
333
    if (producer_sub.size() != consumer_sub.size()) {
36✔
334
        return {};
×
335
    }
×
336
    if (producer_sub.empty()) {
36✔
337
        return {};
×
338
    }
×
339

340
    // Extract producer indvars
341
    SymEngine::vec_sym producer_vars;
36✔
342
    for (auto* loop : producer_loops) {
46✔
343
        producer_vars.push_back(SymEngine::rcp_static_cast<const SymEngine::Symbol>(loop->indvar()));
46✔
344
    }
46✔
345

346
    // Step 1: Solve the linear equation system using SymEngine
347
    // System: producer_sub[d] - consumer_sub[d] = 0, for each dimension d
348
    // Solve for producer_vars in terms of consumer_vars and parameters
349
    SymEngine::vec_basic equations;
36✔
350
    for (size_t d = 0; d < producer_sub.size(); ++d) {
82✔
351
        equations.push_back(symbolic::sub(producer_sub.at(d), consumer_sub.at(d)));
46✔
352
    }
46✔
353

354
    // Need exactly as many equations as unknowns for a unique solution.
355
    // Underdetermined systems (e.g. linearized access with multiple loop vars)
356
    // cannot be uniquely solved and would crash linsolve.
357
    if (equations.size() != producer_vars.size()) {
36✔
358
        return {};
×
359
    }
×
360

361
    SymEngine::vec_basic solution;
36✔
362
    try {
36✔
363
        solution = SymEngine::linsolve(equations, producer_vars);
36✔
364
    } catch (...) {
36✔
365
        return {};
×
366
    }
×
367
    if (solution.size() != producer_vars.size()) {
36✔
368
        return {};
×
369
    }
×
370
    // Build consumer var set for atom validation
371
    symbolic::SymbolSet consumer_var_set;
36✔
372
    for (auto* loop : consumer_loops) {
47✔
373
        consumer_var_set.insert(loop->indvar());
47✔
374
    }
47✔
375

376
    std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> mappings;
36✔
377
    for (size_t i = 0; i < producer_vars.size(); ++i) {
82✔
378
        auto& sol = solution[i];
46✔
379

380
        // Check for invalid solutions
381
        if (SymEngine::is_a<SymEngine::NaN>(*sol) || SymEngine::is_a<SymEngine::Infty>(*sol)) {
46✔
382
            return {};
×
383
        }
×
384

385
        // Validate that solution atoms are consumer vars or parameters
386
        for (const auto& atom : symbolic::atoms(sol)) {
48✔
387
            if (consumer_var_set.count(atom)) {
48✔
388
                continue;
48✔
389
            }
48✔
390
            bool is_param = false;
×
391
            auto it = consumer_assumptions.find(atom);
×
392
            if (it != consumer_assumptions.end() && it->second.constant()) {
×
393
                is_param = true;
×
394
            }
×
395
            if (!is_param) {
×
396
                it = producer_assumptions.find(atom);
×
397
                if (it != producer_assumptions.end() && it->second.constant()) {
×
398
                    is_param = true;
×
399
                }
×
400
            }
×
401
            if (!is_param) {
×
402
                return {};
×
403
            }
×
404
        }
×
405

406
        mappings.push_back({symbolic::symbol(producer_vars[i]->get_name()), symbolic::expand(sol)});
46✔
407
    }
46✔
408
    // Step 2: ISL integrality validation via map composition
409
    // Build an unconstrained producer access map (no domain bounds on producer vars).
410
    // In map fusion, the producer's computation is inlined into the consumer, so
411
    // the producer's original iteration domain is irrelevant. We only need to verify
412
    // that the equation system has an INTEGER solution for every consumer point.
413
    symbolic::Assumptions unconstrained_producer;
36✔
414
    for (auto* loop : producer_loops) {
46✔
415
        symbolic::Assumption a(loop->indvar());
46✔
416
        a.constant(false);
46✔
417
        unconstrained_producer[loop->indvar()] = a;
46✔
418
    }
46✔
419
    for (const auto& [sym, assump] : producer_assumptions) {
140✔
420
        if (assump.constant() && unconstrained_producer.find(sym) == unconstrained_producer.end()) {
140✔
421
            unconstrained_producer[sym] = assump;
46✔
422
        }
46✔
423
    }
140✔
424

425
    std::string producer_map_str = symbolic::expression_to_map_str(producer_sub, unconstrained_producer);
36✔
426
    // Build consumer access map with full domain constraints
427
    std::string consumer_map_str = symbolic::expression_to_map_str(consumer_sub, consumer_assumptions);
36✔
428

429
    isl_ctx* ctx = isl_ctx_alloc();
36✔
430
    isl_options_set_on_error(ctx, ISL_ON_ERROR_CONTINUE);
36✔
431

432
    isl_map* producer_map = isl_map_read_from_str(ctx, producer_map_str.c_str());
36✔
433
    isl_map* consumer_map = isl_map_read_from_str(ctx, consumer_map_str.c_str());
36✔
434

435
    if (!producer_map || !consumer_map) {
36✔
436
        if (producer_map) isl_map_free(producer_map);
×
437
        if (consumer_map) isl_map_free(consumer_map);
×
438
        isl_ctx_free(ctx);
×
439
        return {};
×
440
    }
×
441

442
    // Align parameters between the two maps
443
    isl_space* params_p = isl_space_params(isl_map_get_space(producer_map));
36✔
444
    isl_space* params_c = isl_space_params(isl_map_get_space(consumer_map));
36✔
445
    isl_space* unified = isl_space_align_params(isl_space_copy(params_p), isl_space_copy(params_c));
36✔
446
    isl_space_free(params_p);
36✔
447
    isl_space_free(params_c);
36✔
448

449
    producer_map = isl_map_align_params(producer_map, isl_space_copy(unified));
36✔
450
    consumer_map = isl_map_align_params(consumer_map, isl_space_copy(unified));
36✔
451

452
    // Save consumer domain before consuming consumer_map in composition
453
    isl_set* consumer_domain = isl_map_domain(isl_map_copy(consumer_map));
36✔
454

455
    // Compute composition: consumer_access ∘ inverse(producer_access)
456
    // This checks whether the equation system producer_subset = consumer_subset
457
    // has an integer solution for each consumer domain point.
458
    isl_map* producer_inverse = isl_map_reverse(producer_map);
36✔
459
    isl_map* composition = isl_map_apply_range(consumer_map, producer_inverse);
36✔
460

461
    // Check single-valuedness: each consumer point maps to at most one producer point
462
    bool single_valued = isl_map_is_single_valued(composition) == isl_bool_true;
36✔
463

464
    // Check domain coverage: every consumer point has a valid integer mapping
465
    isl_set* comp_domain = isl_map_domain(composition);
36✔
466

467
    bool domain_covered = isl_set_is_subset(consumer_domain, comp_domain) == isl_bool_true;
36✔
468

469
    isl_set_free(comp_domain);
36✔
470
    isl_set_free(consumer_domain);
36✔
471

472
    // Step 3: Verify producer write range covers consumer read range.
473
    // The producer only writes a subset of the array if its loops have restricted bounds.
474
    // Fusion is invalid if the consumer reads elements the producer never writes.
475
    bool range_covered = false;
36✔
476
    if (single_valued && domain_covered) {
36✔
477
        std::string constrained_producer_map_str = symbolic::expression_to_map_str(producer_sub, producer_assumptions);
34✔
478
        isl_map* constrained_producer = isl_map_read_from_str(ctx, constrained_producer_map_str.c_str());
34✔
479
        isl_map* consumer_map_copy = isl_map_read_from_str(ctx, consumer_map_str.c_str());
34✔
480

481
        if (constrained_producer && consumer_map_copy) {
34✔
482
            constrained_producer = isl_map_align_params(constrained_producer, isl_space_copy(unified));
34✔
483
            consumer_map_copy = isl_map_align_params(consumer_map_copy, isl_space_copy(unified));
34✔
484

485
            isl_set* producer_range = isl_map_range(constrained_producer);
34✔
486
            isl_set* consumer_range = isl_map_range(consumer_map_copy);
34✔
487

488
            // When arguments are swapped (ConsumerIntoProducer), the "producer"/"consumer"
489
            // labels are inverted. Flip the subset check to always verify:
490
            // actual_consumer_read_range ⊆ actual_producer_write_range
491
            if (invert_range_check) {
34✔
492
                range_covered = isl_set_is_subset(producer_range, consumer_range) == isl_bool_true;
4✔
493
            } else {
30✔
494
                range_covered = isl_set_is_subset(consumer_range, producer_range) == isl_bool_true;
30✔
495
            }
30✔
496

497
            isl_set_free(producer_range);
34✔
498
            isl_set_free(consumer_range);
34✔
499
        } else {
34✔
500
            if (constrained_producer) isl_map_free(constrained_producer);
×
501
            if (consumer_map_copy) isl_map_free(consumer_map_copy);
×
502
        }
×
503
    }
34✔
504

505
    isl_space_free(unified);
36✔
506
    isl_ctx_free(ctx);
36✔
507

508
    if (!single_valued || !domain_covered || !range_covered) {
36✔
509
        return {};
5✔
510
    }
5✔
511

512
    return mappings;
31✔
513
}
36✔
514

515
bool MapFusion::find_write_location(
516
    structured_control_flow::StructuredLoop& loop,
517
    const std::string& container,
518
    std::vector<structured_control_flow::StructuredLoop*>& loops,
519
    structured_control_flow::Sequence*& body,
520
    structured_control_flow::Block*& block
521
) {
4✔
522
    loops.push_back(&loop);
4✔
523
    auto& seq = loop.root();
4✔
524

525
    for (size_t i = 0; i < seq.size(); ++i) {
10✔
526
        auto& child = seq.at(i).first;
6✔
527

528
        if (auto* blk = dynamic_cast<structured_control_flow::Block*>(&child)) {
6✔
529
            // Check if this block writes to the container
530
            auto& dataflow = blk->dataflow();
4✔
531
            for (auto& node : dataflow.nodes()) {
14✔
532
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
14✔
533
                if (access == nullptr || access->data() != container) {
14✔
534
                    continue;
12✔
535
                }
12✔
536
                // Write access: has incoming edges (sink node)
537
                if (dataflow.in_degree(*access) > 0 && dataflow.out_degree(*access) == 0) {
2✔
538
                    if (block != nullptr) {
2✔
539
                        // Multiple write blocks found — ambiguous
540
                        return false;
×
541
                    }
×
542
                    body = &seq;
2✔
543
                    block = blk;
2✔
544
                }
2✔
545
            }
2✔
546
        } else if (auto* nested_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&child)) {
4✔
547
            if (!find_write_location(*nested_loop, container, loops, body, block)) {
2✔
548
                return false;
×
549
            }
×
550
            // If we didn't find the write in this subtree, pop the loop back off
551
            if (loops.back() != &loop && block == nullptr) {
2✔
552
                // The recursive call already popped — but we need to check
553
            }
×
554
        }
2✔
555
    }
6✔
556

557
    // If we didn't find the write in this subtree, remove this loop from the chain
558
    if (block == nullptr) {
4✔
559
        loops.pop_back();
×
560
    }
×
561

562
    return true;
4✔
563
}
4✔
564

565
bool MapFusion::find_read_location(
566
    structured_control_flow::StructuredLoop& loop,
567
    const std::string& container,
568
    std::vector<structured_control_flow::StructuredLoop*>& loops,
569
    structured_control_flow::Sequence*& body
570
) {
2✔
571
    loops.push_back(&loop);
2✔
572
    auto& seq = loop.root();
2✔
573

574
    for (size_t i = 0; i < seq.size(); ++i) {
5✔
575
        auto& child = seq.at(i).first;
3✔
576

577
        if (auto* blk = dynamic_cast<structured_control_flow::Block*>(&child)) {
3✔
578
            // Check if this block reads from the container
579
            auto& dataflow = blk->dataflow();
2✔
580
            for (auto& node : dataflow.nodes()) {
8✔
581
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
8✔
582
                if (access == nullptr || access->data() != container) {
8✔
583
                    continue;
7✔
584
                }
7✔
585
                // Read access: has outgoing edges (source node)
586
                if (dataflow.in_degree(*access) == 0 && dataflow.out_degree(*access) > 0) {
1✔
587
                    if (body != nullptr && body != &seq) {
1✔
588
                        // Reads at different sequence levels — ambiguous
589
                        return false;
×
590
                    }
×
591
                    body = &seq;
1✔
592
                }
1✔
593
            }
1✔
594
        } else if (auto* nested_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&child)) {
2✔
595
            if (!find_read_location(*nested_loop, container, loops, body)) {
1✔
596
                return false;
×
597
            }
×
598
        }
1✔
599
    }
3✔
600

601
    // If we didn't find any reads in this subtree, remove this loop from the chain
602
    if (body == nullptr) {
2✔
603
        loops.pop_back();
×
604
    }
×
605

606
    return true;
2✔
607
}
2✔
608

609
bool MapFusion::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
41✔
610
    fusion_candidates_.clear();
41✔
611

612
    // no use in fusing empty loops. Also presumed to not be empty further down
613
    if (first_map_.root().size() == 0 || second_loop_.root().size() == 0) {
41✔
614
        return false;
×
615
    }
×
616

617
    // Criterion: Get parent scope and verify both loops are sequential children
618
    auto* first_parent = first_map_.get_parent();
41✔
619
    auto* second_parent = second_loop_.get_parent();
41✔
620
    if (first_parent == nullptr || second_parent == nullptr) {
41✔
621
        return false;
×
622
    }
×
623
    if (first_parent != second_parent) {
41✔
624
        return false;
×
625
    }
×
626

627
    auto* parent_sequence = dynamic_cast<structured_control_flow::Sequence*>(first_parent);
41✔
628
    if (parent_sequence == nullptr) {
41✔
629
        return false;
×
630
    }
×
631

632
    int first_index = parent_sequence->index(first_map_);
41✔
633
    int second_index = parent_sequence->index(second_loop_);
41✔
634
    if (first_index == -1 || second_index == -1) {
41✔
635
        return false;
×
636
    }
×
637
    if (require_consecutive_ && second_index != first_index + 1) {
41✔
638
        return false;
1✔
639
    }
1✔
640

641
    // Criterion: Transition between maps should have no assignments
642
    if (require_consecutive_) {
40✔
643
        auto& transition = parent_sequence->at(first_index).second;
40✔
644
        if (!transition.empty()) {
40✔
645
            return false;
×
646
        }
×
647
    }
40✔
648
    // Determine fusion pattern based on nesting properties
649
    auto& loop_analysis = analysis_manager.get<analysis::LoopAnalysis>();
40✔
650
    auto first_loop_info = loop_analysis.loop_info(&first_map_);
40✔
651
    auto second_loop_info = loop_analysis.loop_info(&second_loop_);
40✔
652

653
    auto limit_depth = 0;
40✔
654

655
    bool first_nested = first_loop_info.is_perfectly_nested;
40✔
656
    bool second_nested = second_loop_info.is_perfectly_nested;
40✔
657

658
    // Both non-perfectly-nested: not supported
659
    if (!first_nested && !second_nested) {
40✔
660
        return false;
1✔
661
    }
1✔
662

663
    if (first_nested && second_nested) {
39✔
664
        // Pattern 1: Both perfectly nested — producer into consumer (original path)
665
        direction_ = FusionDirection::ProducerIntoConsumer;
36✔
666
    } else if (!first_nested && second_nested) {
36✔
667
        // Pattern 2: Producer non-perfectly-nested, consumer perfectly nested
668
        direction_ = FusionDirection::ConsumerIntoProducer;
2✔
669
    } else {
2✔
670
        // Reverse Pattern 2: Producer perfectly nested, consumer non-perfectly-nested
671
        direction_ = FusionDirection::ProducerIntoConsumer;
1✔
672
    }
1✔
673

674
    // The side being inlined must be all-parallel (all Maps) so iterations can be reordered.
675
    // ProducerIntoConsumer: both sides must be all-parallel. The producer is replicated at
676
    // each consumer site — it must be reorderable. The consumer must also be all-parallel
677
    // because a sequential (For) loop would re-execute the inlined producer on every
678
    // iteration (e.g. init T=0 fused into For(k){T+=A[k]} re-initializes each k).
679
    // ConsumerIntoProducer: only the consumer (inlined side) must be all-parallel.
680
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
39✔
681
        if (!first_loop_info.is_perfectly_parallel) {
37✔
NEW
682
            return false;
×
683
        } else if (!second_loop_info.is_perfectly_parallel) {
37✔
684
            if (second_loop_info.is_perfectly_nested) { // we can check if the innermost loop is non parallel, but the
2✔
685
                                                        // outers are
686
            }
2✔
687
            return false;
2✔
688
        }
2✔
689
    } else {
37✔
690
        if (!second_loop_info.is_perfectly_parallel) {
2✔
NEW
691
            return false;
×
NEW
692
        }
×
693
    }
2✔
694

695
    // Locate producer write point
696
    producer_loops_.clear();
37✔
697
    producer_body_ = nullptr;
37✔
698
    producer_block_ = nullptr;
37✔
699

700
    if (first_nested) {
37✔
701
        // Perfectly nested: walk the at(0).first chain
702
        producer_loops_.push_back(&first_map_);
35✔
703
        producer_body_ = &first_map_.root();
35✔
704
        structured_control_flow::ControlFlowNode* node = &first_map_.root().at(0).first;
35✔
705
        int level = 1;
35✔
706
        while (auto* nested = dynamic_cast<structured_control_flow::StructuredLoop*>(node)) {
44✔
707
            if (limit_depth && ++level > limit_depth) {
10✔
NEW
708
                break;
×
NEW
709
            }
×
710
            producer_loops_.push_back(nested);
10✔
711
            producer_body_ = &nested->root();
10✔
712
            if (nested->root().size() == 0) return false;
10✔
713
            node = &nested->root().at(0).first;
9✔
714
        }
9✔
715
        producer_block_ = dynamic_cast<structured_control_flow::Block*>(node);
34✔
716
        if (producer_block_ == nullptr) {
34✔
717
            return false;
×
718
        }
×
719
        // If the body has multiple children, the at(0) walk does not guarantee
720
        // we found the correct (or unique) write block. Fall back to deferred
721
        // find_write_location resolution.
722
        if (producer_body_->size() != 1) {
34✔
723
            producer_block_ = nullptr;
×
724
            // Keep producer_loops_ and producer_body_ from the walk — they are
725
            // still valid for the loop chain. find_write_location will re-resolve
726
            // the block within producer_body_.
727
        }
×
728
    } else {
34✔
729
        // Non-perfectly-nested: search recursively for the write block
730
        // We need to know which containers to look for, but we don't know them yet.
731
        // Defer write location search until after fusion_containers are identified.
732
    }
2✔
733

734
    // Locate consumer read point
735
    consumer_loops_.clear();
36✔
736
    consumer_body_ = nullptr;
36✔
737

738
    if (second_nested) {
36✔
739
        // Perfectly nested: walk the at(0).first chain through all loop types.
740
        // Reduction patterns (e.g. Map{Map{For{T[i,j]+=...}}}) are rejected by
741
        // the is_perfectly_parallel check — For loops make it non-parallel.
742
        consumer_loops_.push_back(&second_loop_);
35✔
743
        consumer_body_ = &second_loop_.root();
35✔
744
        structured_control_flow::ControlFlowNode* node = &second_loop_.root().at(0).first;
35✔
745
        int level = 1;
35✔
746
        while (auto* nested = dynamic_cast<structured_control_flow::StructuredLoop*>(node)) {
45✔
747
            if (limit_depth && ++level > limit_depth) {
11✔
NEW
748
                break;
×
NEW
749
            }
×
750
            consumer_loops_.push_back(nested);
11✔
751
            consumer_body_ = &nested->root();
11✔
752
            if (nested->root().size() == 0) return false;
11✔
753
            node = &nested->root().at(0).first;
10✔
754
        }
10✔
755
    } else {
35✔
756
        // Non-perfectly-nested: defer read location search until after fusion_containers are identified.
757
    }
1✔
758

759
    // Get arguments analysis to identify inputs/outputs of each loop
760
    auto& arguments_analysis = analysis_manager.get<analysis::ArgumentsAnalysis>();
35✔
761
    auto first_args = arguments_analysis.arguments(analysis_manager, first_map_);
35✔
762
    auto second_args = arguments_analysis.arguments(analysis_manager, second_loop_);
35✔
763

764
    std::unordered_set<std::string> first_inputs;
35✔
765
    std::unordered_set<std::string> first_outputs;
35✔
766
    for (const auto& [name, arg] : first_args) {
114✔
767
        if (arg.is_output) {
114✔
768
            first_outputs.insert(name);
36✔
769
        }
36✔
770
        if (arg.is_input) {
114✔
771
            first_inputs.insert(name);
114✔
772
        }
114✔
773
    }
114✔
774

775
    // First pass: identify fusion containers (producer writes, consumer reads)
776
    std::unordered_set<std::string> fusion_containers;
35✔
777
    for (const auto& [name, arg] : second_args) {
120✔
778
        if (first_outputs.contains(name) && arg.is_input) {
120✔
779
            fusion_containers.insert(name);
34✔
780
        }
34✔
781
    }
120✔
782
    if (fusion_containers.empty()) {
35✔
783
        return false;
1✔
784
    }
1✔
785

786
    // Second pass: check for conflicts on non-fusion containers
787
    for (const auto& [name, arg] : second_args) {
115✔
788
        bool is_fusion = fusion_containers.contains(name);
115✔
789
        if (first_outputs.contains(name) && arg.is_output && !is_fusion) {
115✔
790
            return false;
×
791
        }
×
792
        if (first_inputs.contains(name) && arg.is_output && !is_fusion) {
115✔
793
            return false;
1✔
794
        }
1✔
795
    }
115✔
796

797
    // Now that we know the fusion containers, resolve deferred locations
798
    if (producer_block_ == nullptr) {
33✔
799
        // Non-perfectly-nested producer (or perfectly-nested with multi-block body):
800
        // find write location for the first fusion container.
801
        // All fusion containers must be written at the same block for this to work.
802
        for (const auto& container : fusion_containers) {
2✔
803
            std::vector<structured_control_flow::StructuredLoop*> write_loops;
2✔
804
            structured_control_flow::Sequence* write_body = nullptr;
2✔
805
            structured_control_flow::Block* write_block = nullptr;
2✔
806

807
            if (!find_write_location(first_map_, container, write_loops, write_body, write_block)) {
2✔
808
                return false;
×
809
            }
×
810
            if (write_block == nullptr) {
2✔
811
                return false;
×
812
            }
×
813

814
            if (producer_block_ == nullptr) {
2✔
815
                // First container: set the locations
816
                producer_loops_ = write_loops;
2✔
817
                producer_body_ = write_body;
2✔
818
                producer_block_ = write_block;
2✔
819
            } else {
2✔
820
                // Subsequent containers must be in the same block
821
                if (write_block != producer_block_) {
×
822
                    return false;
×
823
                }
×
824
            }
×
825
        }
2✔
826
    }
2✔
827

828
    if (!second_nested) {
33✔
829
        // Non-perfectly-nested consumer: find read location for the first fusion container
830
        // All fusion containers must be read at the same sequence for this to work
831
        for (const auto& container : fusion_containers) {
1✔
832
            std::vector<structured_control_flow::StructuredLoop*> read_loops;
1✔
833
            structured_control_flow::Sequence* read_body = nullptr;
1✔
834

835
            if (!find_read_location(second_loop_, container, read_loops, read_body)) {
1✔
836
                return false;
×
837
            }
×
838
            if (read_body == nullptr) {
1✔
839
                return false;
×
840
            }
×
841

842
            if (consumer_body_ == nullptr) {
1✔
843
                // First container: set the locations
844
                consumer_loops_ = read_loops;
1✔
845
                consumer_body_ = read_body;
1✔
846
            } else {
1✔
847
                // Subsequent containers must be at the same sequence
848
                if (read_body != consumer_body_) {
×
849
                    return false;
×
850
                }
×
851
            }
×
852
        }
1✔
853
    }
1✔
854

855
    // Get assumptions for the resolved write/read locations
856
    // Include trivial bounds from types to help delinearization with symbolic strides
857
    auto& assumptions_analysis = analysis_manager.get<analysis::AssumptionsAnalysis>();
33✔
858
    auto& producer_assumptions = assumptions_analysis.get(*producer_block_, true);
33✔
859
    auto& consumer_assumptions = assumptions_analysis.get(consumer_body_->at(0).first, true);
33✔
860

861
    // Check if producer actually reads a fusion container in the dataflow.
862
    // If so, ProducerIntoConsumer is unsafe (original producer loop mutates the array
863
    // before the inlined copy reads it). Force ConsumerIntoProducer.
864
    // We check the dataflow directly rather than ArgumentsAnalysis, because the latter
865
    // conservatively marks written containers as also read.
866
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
33✔
867
        auto& first_dataflow_check = producer_block_->dataflow();
31✔
868
        bool producer_reads_fusion = false;
31✔
869
        for (const auto& container : fusion_containers) {
31✔
870
            for (auto& node : first_dataflow_check.nodes()) {
99✔
871
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
99✔
872
                if (access != nullptr && access->data() == container && first_dataflow_check.out_degree(*access) > 0) {
99✔
873
                    producer_reads_fusion = true;
2✔
874
                    break;
2✔
875
                }
2✔
876
            }
99✔
877
            if (producer_reads_fusion) break;
31✔
878
        }
31✔
879
        if (producer_reads_fusion) {
31✔
880
            direction_ = FusionDirection::ConsumerIntoProducer;
2✔
881
            // Re-check: consumer must be all-parallel for ConsumerIntoProducer
882
            if (!second_loop_info.is_perfectly_parallel) {
2✔
883
                return false;
×
884
            }
×
885
        }
2✔
886
    }
31✔
887

888
    // ProducerIntoConsumer only deep-copies producer_block_ into the consumer body.
889
    // If the producer body has multiple blocks (e.g. from prior BlockFusion merging
890
    // a previous fusion's writeback + inlined blocks), the write block may depend on
891
    // intermediates produced by earlier blocks that would NOT be copied. Reject.
892
    if (direction_ == FusionDirection::ProducerIntoConsumer && producer_body_->size() > 1) {
33✔
893
        return false;
×
894
    }
×
895

896
    std::unordered_map<std::string, const data_flow::Subset*> producer_subsets;
33✔
897

898
    // For each fusion container, find the producer memlet and collect unique consumer subsets
899
    auto& first_dataflow = producer_block_->dataflow();
33✔
900
    for (const auto& container : fusion_containers) {
33✔
901
        // Find unique producer write in first map
902
        data_flow::Memlet* producer_memlet = nullptr;
33✔
903

904
        for (auto& node : first_dataflow.nodes()) {
106✔
905
            auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
106✔
906
            if (access == nullptr || access->data() != container) {
106✔
907
                continue;
71✔
908
            }
71✔
909
            // Skip read-only access nodes (producer reads the fusion container)
910
            if (first_dataflow.in_degree(*access) == 0) {
35✔
911
                continue;
2✔
912
            }
2✔
913
            // Write access: must have exactly one incoming edge and no outgoing
914
            if (first_dataflow.in_degree(*access) != 1 || first_dataflow.out_degree(*access) != 0) {
33✔
915
                return false;
×
916
            }
×
917
            auto& iedge = *first_dataflow.in_edges(*access).begin();
33✔
918
            if (iedge.type() != data_flow::MemletType::Computational) {
33✔
919
                return false;
×
920
            }
×
921
            if (producer_memlet != nullptr) {
33✔
922
                return false;
×
923
            }
×
924
            producer_memlet = &iedge;
33✔
925
        }
33✔
926
        if (producer_memlet == nullptr) {
33✔
927
            return false;
×
928
        }
×
929

930
        const auto& producer_subset = producer_memlet->subset();
33✔
931
        if (producer_subset.empty()) {
33✔
932
            return false;
×
933
        } else {
33✔
934
            producer_subsets.emplace(container, &producer_subset);
33✔
935
        }
33✔
936
    }
33✔
937

938
    FusionConsumerSubsetVisitor consumer_visitor(producer_subsets);
33✔
939
    bool abort = consumer_visitor.dispatch(*consumer_body_);
33✔
940
    if (abort) {
33✔
941
        return false;
1✔
942
    }
1✔
943

944
    for (auto [container, unique_subsets] : consumer_visitor.unique_subsets_per_container_) {
32✔
945
        auto& producer_subset = *producer_subsets.at(container);
32✔
946
        // For each unique consumer subset, solve index mappings and create a FusionCandidate
947
        // The direction determines which side's indvars are solved for
948
        for (const auto& consumer_subset : unique_subsets) {
36✔
949
            std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> mappings;
36✔
950

951
            if (direction_ == FusionDirection::ProducerIntoConsumer) {
36✔
952
                // Solve producer indvars in terms of consumer indvars
953
                mappings = solve_subsets(
32✔
954
                    producer_subset,
32✔
955
                    consumer_subset,
32✔
956
                    producer_loops_,
32✔
957
                    consumer_loops_,
32✔
958
                    producer_assumptions,
32✔
959
                    consumer_assumptions
32✔
960
                );
32✔
961
            } else {
32✔
962
                // ConsumerIntoProducer: solve consumer indvars in terms of producer indvars
963
                // Arguments are swapped, so invert the range check direction
964
                mappings = solve_subsets(
4✔
965
                    consumer_subset,
4✔
966
                    producer_subset,
4✔
967
                    consumer_loops_,
4✔
968
                    producer_loops_,
4✔
969
                    consumer_assumptions,
4✔
970
                    producer_assumptions,
4✔
971
                    true
4✔
972
                );
4✔
973
            }
4✔
974

975
            if (mappings.empty()) {
36✔
976
                return false;
5✔
977
            }
5✔
978

979
            FusionCandidate candidate;
31✔
980
            candidate.container = container;
31✔
981
            candidate.consumer_subset = consumer_subset;
31✔
982
            candidate.index_mappings = std::move(mappings);
31✔
983

984
            fusion_candidates_.push_back(candidate);
31✔
985
        }
31✔
986
    }
32✔
987

988
    // Criterion: At least one valid fusion candidate
989
    return !fusion_candidates_.empty();
27✔
990
}
32✔
991

992
void MapFusion::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
19✔
993
    auto& sdfg = builder.subject();
19✔
994

995
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
19✔
996
        // Pattern 1 + Reverse Pattern 2: Inline producer blocks into consumer's read body
997
        auto& first_dataflow = producer_block_->dataflow();
16✔
998

999
        // For each fusion candidate, create a temp and insert a producer block
1000
        std::vector<std::string> candidate_temps;
16✔
1001

1002
        for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
34✔
1003
            auto& candidate = fusion_candidates_[cand_idx];
18✔
1004

1005
            auto& container_type = sdfg.type(candidate.container);
18✔
1006
            std::string temp_name = builder.find_new_name("_fused_tmp");
18✔
1007
            types::Scalar tmp_type(container_type.primitive_type());
18✔
1008
            builder.add_container(temp_name, tmp_type);
18✔
1009
            candidate_temps.push_back(temp_name);
18✔
1010

1011
            // Insert a producer block at the beginning of the consumer's body
1012
            auto& first_child = consumer_body_->at(0).first;
18✔
1013
            control_flow::Assignments empty_assignments;
18✔
1014
            auto& new_block = builder.add_block_before(*consumer_body_, first_child, empty_assignments);
18✔
1015
            structured_control_flow::Block* empty_block = nullptr;
18✔
1016

1017
            // Deep copy all nodes from producer block to new block
1018
            std::unordered_map<const data_flow::DataFlowNode*, data_flow::DataFlowNode*> node_mapping;
18✔
1019
            std::unordered_map<std::string, std::string> intermediate_renames;
18✔
1020
            for (auto& node : first_dataflow.nodes()) {
59✔
1021
                node_mapping[&node] = &builder.copy_node(new_block, node);
59✔
1022
                auto* copied = node_mapping[&node];
59✔
1023
                if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(copied)) {
59✔
1024
                    if (access_node->data() == candidate.container) {
41✔
1025
                        access_node->data(temp_name);
18✔
1026
                    } else if (access_node->data() == first_map_.indvar()->get_name()) {
23✔
1027
                        // Determine the new expression for the index variable of the first map
1028
                        symbolic::Expression new_expr = SymEngine::null;
2✔
1029
                        for (auto& c : fusion_candidates_) {
2✔
1030
                            for (auto& [sym, expr] : c.index_mappings) {
2✔
1031
                                if (symbolic::eq(sym, first_map_.indvar())) {
2✔
1032
                                    new_expr = expr;
2✔
1033
                                    break;
2✔
1034
                                }
2✔
1035
                            }
2✔
1036
                            if (!new_expr.is_null()) {
2✔
1037
                                break;
2✔
1038
                            }
2✔
1039
                        }
2✔
1040

1041
                        if (new_expr.is_null() || symbolic::eq(new_expr, second_loop_.indvar())) {
2✔
1042
                            // Simple case: The new expression is simply the index variable of the second loop
1043
                            access_node->data(second_loop_.indvar()->get_name());
1✔
1044
                        } else {
1✔
1045
                            // Complex case: Add an empty block before the new block (if necessary) and store the
1046
                            // shifted index into a new temporary variable with an assignment. Then, replace the index
1047
                            // variable with the new temporary variable
1048
                            if (!empty_block) {
1✔
1049
                                empty_block = &builder.add_block_before(*consumer_body_, new_block, empty_assignments);
1✔
1050
                            }
1✔
1051
                            auto new_index_name = builder.find_new_name();
1✔
1052
                            builder
1✔
1053
                                .add_container(new_index_name, builder.subject().type(second_loop_.indvar()->get_name()));
1✔
1054
                            consumer_body_->at(0)
1✔
1055
                                .second.assignments()
1✔
1056
                                .insert({symbolic::symbol(new_index_name), new_expr});
1✔
1057
                            access_node->data(new_index_name);
1✔
1058
                        }
1✔
1059
                    } else if (first_dataflow.in_degree(node) > 0 && first_dataflow.out_degree(node) > 0 &&
21✔
1060
                               dynamic_cast<const types::Scalar*>(&sdfg.type(access_node->data())) != nullptr) {
21✔
1061
                        // SSA Dataflow required to check for non-local use of the access node's container.
1062
                        // Intermediate access node (e.g. from a prior BlockFusion): clone
1063
                        // its container so each inlined copy gets its own private scalar
1064
                        auto it = intermediate_renames.find(access_node->data());
×
1065
                        if (it == intermediate_renames.end()) {
×
1066
                            std::string fresh = builder.find_new_name(access_node->data());
×
1067
                            builder.add_container(fresh, sdfg.type(access_node->data()));
×
1068
                            intermediate_renames[access_node->data()] = fresh;
×
1069
                        }
×
1070
                        access_node->data(intermediate_renames[access_node->data()]);
×
1071
                    }
×
1072
                }
41✔
1073
            }
59✔
1074

1075
            // Add memlets with index substitution (producer indvars → consumer expressions)
1076
            for (auto& edge : first_dataflow.edges()) {
41✔
1077
                auto& src_node = edge.src();
41✔
1078
                auto& dst_node = edge.dst();
41✔
1079

1080
                const types::IType* base_type = &edge.base_type();
41✔
1081
                data_flow::Subset new_subset;
41✔
1082
                for (const auto& dim : edge.subset()) {
41✔
1083
                    auto new_dim = dim;
39✔
1084
                    for (const auto& [pvar, mapping] : candidate.index_mappings) {
52✔
1085
                        new_dim = symbolic::subs(new_dim, pvar, mapping);
52✔
1086
                    }
52✔
1087
                    new_dim = symbolic::expand(new_dim);
39✔
1088
                    new_subset.push_back(new_dim);
39✔
1089
                }
39✔
1090

1091
                // For output edges to temp scalar, use empty subset
1092
                auto* dst_access = dynamic_cast<data_flow::AccessNode*>(&dst_node);
41✔
1093
                if (dst_access != nullptr && dst_access->data() == candidate.container &&
41✔
1094
                    first_dataflow.in_degree(*dst_access) > 0) {
41✔
1095
                    new_subset.clear();
18✔
1096
                    base_type = &tmp_type;
18✔
1097
                }
18✔
1098

1099
                builder.add_memlet(
41✔
1100
                    new_block,
41✔
1101
                    *node_mapping[&src_node],
41✔
1102
                    edge.src_conn(),
41✔
1103
                    *node_mapping[&dst_node],
41✔
1104
                    edge.dst_conn(),
41✔
1105
                    new_subset,
41✔
1106
                    *base_type,
41✔
1107
                    edge.debug_info()
41✔
1108
                );
41✔
1109
            }
41✔
1110
        }
18✔
1111

1112
        // Update all read accesses in consumer blocks to point to the appropriate temp
1113
        size_t num_producer_blocks = fusion_candidates_.size();
16✔
1114

1115
        FusionConsumerUpdateVisitor update_visitor(builder, fusion_candidates_, candidate_temps);
16✔
1116
        update_visitor.dispatch_partial_sequence(*consumer_body_, num_producer_blocks, consumer_body_->size());
16✔
1117

1118
    } else {
16✔
1119
        // ConsumerIntoProducer (Pattern 2): Inline consumer blocks into the producer's write body
1120
        // Modify the producer block in-place to write to a temp scalar, add a writeback block
1121
        // for the original array, then copy consumer blocks reading from the temp.
1122

1123
        std::vector<std::string> candidate_temps;
3✔
1124
        auto& producer_dataflow = producer_block_->dataflow();
3✔
1125

1126
        for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
6✔
1127
            auto& candidate = fusion_candidates_[cand_idx];
3✔
1128

1129
            auto& container_type = sdfg.type(candidate.container);
3✔
1130
            std::string temp_name = builder.find_new_name("_fused_tmp");
3✔
1131
            types::Scalar tmp_type(container_type.primitive_type());
3✔
1132
            builder.add_container(temp_name, tmp_type);
3✔
1133
            candidate_temps.push_back(temp_name);
3✔
1134

1135
            // Step 1: Modify the original producer block to write to _fused_tmp
1136
            data_flow::Subset original_write_subset;
3✔
1137
            for (auto& node : producer_dataflow.nodes()) {
6✔
1138
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
6✔
1139
                if (access == nullptr || access->data() != candidate.container) continue;
6✔
1140
                if (producer_dataflow.in_degree(*access) == 0) continue;
3✔
1141

1142
                // This is the write access node — save the original subset, then redirect
1143
                for (auto& in_edge : producer_dataflow.in_edges(*access)) {
3✔
1144
                    original_write_subset = in_edge.subset();
3✔
1145
                    in_edge.set_subset({});
3✔
1146
                    in_edge.set_base_type(tmp_type);
3✔
1147
                }
3✔
1148
                access->data(temp_name);
3✔
1149
                break;
3✔
1150
            }
3✔
1151

1152
            // Step 2: Add a writeback block: container[original_subset] = _fused_tmp
1153
            control_flow::Assignments empty_assignments;
3✔
1154
            auto& wb_block = builder.add_block_after(*producer_body_, *producer_block_, empty_assignments);
3✔
1155
            auto& wb_src = builder.add_access(wb_block, temp_name);
3✔
1156
            auto& wb_dst = builder.add_access(wb_block, candidate.container);
3✔
1157
            auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
3✔
1158
            builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", {});
3✔
1159
            builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, original_write_subset);
3✔
1160

1161
            // Step 3: Copy consumer blocks after the writeback block
1162
            structured_control_flow::ControlFlowNode* last_inserted = &wb_block;
3✔
1163

1164
            for (size_t i = 0; i < consumer_body_->size(); ++i) {
6✔
1165
                auto* consumer_block = dynamic_cast<structured_control_flow::Block*>(&consumer_body_->at(i).first);
3✔
1166
                if (consumer_block == nullptr) {
3✔
1167
                    continue;
×
1168
                }
×
1169

1170
                auto& consumer_dataflow = consumer_block->dataflow();
3✔
1171

1172
                // Check if this block reads from the fusion container
1173
                bool reads_container = false;
3✔
1174
                for (auto& node : consumer_dataflow.nodes()) {
11✔
1175
                    auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
11✔
1176
                    if (access != nullptr && access->data() == candidate.container &&
11✔
1177
                        consumer_dataflow.out_degree(*access) > 0) {
11✔
1178
                        reads_container = true;
3✔
1179
                        break;
3✔
1180
                    }
3✔
1181
                }
11✔
1182
                if (!reads_container) {
3✔
1183
                    continue;
×
1184
                }
×
1185

1186
                // Insert a new block after the last inserted block in the producer's body
1187
                auto& new_block = builder.add_block_after(*producer_body_, *last_inserted, empty_assignments);
3✔
1188
                structured_control_flow::Block* empty_block = nullptr;
3✔
1189

1190
                // Deep copy all nodes from consumer block
1191
                std::unordered_map<const data_flow::DataFlowNode*, data_flow::DataFlowNode*> node_mapping;
3✔
1192
                std::unordered_map<std::string, std::string> intermediate_renames;
3✔
1193
                for (auto& node : consumer_dataflow.nodes()) {
11✔
1194
                    node_mapping[&node] = &builder.copy_node(new_block, node);
11✔
1195
                    auto* copied = node_mapping[&node];
11✔
1196
                    if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(copied)) {
11✔
1197
                        if (access_node->data() == candidate.container) {
8✔
1198
                            // Only rename read access nodes to temp; keep write access nodes
1199
                            // pointing to the original container
1200
                            if (consumer_dataflow.in_degree(node) == 0) {
4✔
1201
                                access_node->data(temp_name);
3✔
1202
                            }
3✔
1203
                        } else if (consumer_dataflow.in_degree(node) > 0 && consumer_dataflow.out_degree(node) > 0 &&
4✔
1204
                                   dynamic_cast<const types::Scalar*>(&sdfg.type(access_node->data())) != nullptr) {
4✔
1205
                            // SSA Dataflow required to check for non-local use of the access node's container.
1206
                            // Intermediate access node (e.g. from a prior BlockFusion): clone
1207
                            // its container so each inlined copy gets its own private scalar
1208
                            auto it = intermediate_renames.find(access_node->data());
×
1209
                            if (it == intermediate_renames.end()) {
×
1210
                                std::string fresh = builder.find_new_name(access_node->data());
×
1211
                                builder.add_container(fresh, sdfg.type(access_node->data()));
×
1212
                                intermediate_renames[access_node->data()] = fresh;
×
1213
                            }
×
1214
                            access_node->data(intermediate_renames[access_node->data()]);
×
1215
                        }
×
1216
                        if (access_node->data() == second_loop_.indvar()->get_name() &&
8✔
1217
                            consumer_dataflow.in_degree(node) == 0) {
8✔
1218
                            // Determine the new expression for the index variable of the second loop
1219
                            symbolic::Expression new_expr = SymEngine::null;
×
1220
                            for (auto& c : fusion_candidates_) {
×
1221
                                for (auto& [sym, expr] : c.index_mappings) {
×
1222
                                    if (symbolic::eq(sym, second_loop_.indvar())) {
×
1223
                                        new_expr = expr;
×
1224
                                        break;
×
1225
                                    }
×
1226
                                }
×
1227
                                if (!new_expr.is_null()) {
×
1228
                                    break;
×
1229
                                }
×
1230
                            }
×
1231

1232
                            if (new_expr.is_null() || symbolic::eq(new_expr, first_map_.indvar())) {
×
1233
                                // Simple case: The new expression is simply the index variable of the first map
1234
                                access_node->data(first_map_.indvar()->get_name());
×
1235
                            } else {
×
1236
                                // Complex case: Add an empty block before the new block (if necessary) and store the
1237
                                // shifted index into a new temporary variable with an assignment. Then, replace the
1238
                                // index variable with the new temporary variable
1239
                                if (!empty_block) {
×
1240
                                    empty_block =
×
1241
                                        &builder.add_block_before(*producer_body_, new_block, empty_assignments);
×
1242
                                }
×
1243
                                auto new_index_name = builder.find_new_name();
×
1244
                                builder.add_container(
×
1245
                                    new_index_name, builder.subject().type(first_map_.indvar()->get_name())
×
1246
                                );
×
1247
                                producer_body_->at(0)
×
1248
                                    .second.assignments()
×
1249
                                    .insert({symbolic::symbol(new_index_name), new_expr});
×
1250
                                access_node->data(new_index_name);
×
1251
                            }
×
1252
                        }
×
1253
                    }
8✔
1254
                }
11✔
1255

1256
                // Add memlets with index substitution (consumer indvars → producer expressions)
1257
                for (auto& edge : consumer_dataflow.edges()) {
8✔
1258
                    auto& src_node = edge.src();
8✔
1259
                    auto& dst_node = edge.dst();
8✔
1260

1261
                    const types::IType* base_type = &edge.base_type();
8✔
1262
                    data_flow::Subset new_subset;
8✔
1263
                    for (const auto& dim : edge.subset()) {
9✔
1264
                        auto new_dim = dim;
9✔
1265
                        for (const auto& [cvar, mapping] : candidate.index_mappings) {
13✔
1266
                            new_dim = symbolic::subs(new_dim, cvar, mapping);
13✔
1267
                        }
13✔
1268
                        new_dim = symbolic::expand(new_dim);
9✔
1269
                        new_subset.push_back(new_dim);
9✔
1270
                    }
9✔
1271

1272
                    // For read edges from temp scalar, use empty subset
1273
                    auto* src_access = dynamic_cast<data_flow::AccessNode*>(&src_node);
8✔
1274
                    if (src_access != nullptr && src_access->data() == candidate.container &&
8✔
1275
                        consumer_dataflow.in_degree(*src_access) == 0) {
8✔
1276
                        new_subset.clear();
3✔
1277
                        base_type = &tmp_type;
3✔
1278
                    }
3✔
1279

1280
                    builder.add_memlet(
8✔
1281
                        new_block,
8✔
1282
                        *node_mapping[&src_node],
8✔
1283
                        edge.src_conn(),
8✔
1284
                        *node_mapping[&dst_node],
8✔
1285
                        edge.dst_conn(),
8✔
1286
                        new_subset,
8✔
1287
                        *base_type,
8✔
1288
                        edge.debug_info()
8✔
1289
                    );
8✔
1290
                }
8✔
1291

1292
                last_inserted = &new_block;
3✔
1293
            }
3✔
1294
        }
3✔
1295

1296
        // Remove the consumer loop
1297
        auto* parent = second_loop_.get_parent();
3✔
1298
        auto* parent_seq = dynamic_cast<structured_control_flow::Sequence*>(parent);
3✔
1299
        if (parent_seq != nullptr) {
3✔
1300
            int idx = parent_seq->index(second_loop_);
3✔
1301
            if (idx >= 0) {
3✔
1302
                builder.remove_child(*parent_seq, static_cast<size_t>(idx));
3✔
1303
            }
3✔
1304
        }
3✔
1305
    }
3✔
1306

1307
    analysis_manager.invalidate_all();
19✔
1308
    applied_ = true;
19✔
1309
}
19✔
1310

1311
/// Serialize this transformation into `j`.
///
/// Writes `transformation_type` (the transformation's name) and a `subgraph` object with two entries:
/// "0" describes the producer map (always typed "map"), and "1" describes the consumer loop, typed
/// "map" when it is a Map and "for" otherwise.
void MapFusion::to_json(nlohmann::json& j) const {
    // The consumer may be any StructuredLoop; record whether it is specifically a Map.
    const bool consumer_is_map = dynamic_cast<structured_control_flow::Map*>(&second_loop_) != nullptr;
    const std::string consumer_kind = consumer_is_map ? "map" : "for";

    nlohmann::json producer_entry = {{"element_id", first_map_.element_id()}, {"type", "map"}};
    nlohmann::json consumer_entry = {{"element_id", second_loop_.element_id()}, {"type", consumer_kind}};

    j["transformation_type"] = this->name();
    j["subgraph"] = {{"0", producer_entry}, {"1", consumer_entry}};
}
1✔
1322

1323
/// Reconstruct a MapFusion from a JSON description (the shape written by `to_json`).
///
/// Reads `desc["subgraph"]["0"]["element_id"]` as the producer Map and
/// `desc["subgraph"]["1"]["element_id"]` as the consumer StructuredLoop, then resolves both IDs via `builder`.
///
/// @param builder Builder whose SDFG is searched for the referenced elements.
/// @param desc    JSON description of the transformation.
/// @return A MapFusion bound to the resolved producer map and consumer loop.
/// @throws InvalidTransformationDescriptionException if a required key is missing, an ID does not
///         resolve to an element, or an element is not of the expected kind.
MapFusion MapFusion::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& desc) {
    // Checked access: const operator[] on a missing key is undefined behaviour in nlohmann::json,
    // so a malformed description must be rejected explicitly. Only out_of_range (missing key) is
    // translated; type errors keep propagating as before.
    size_t first_map_id = 0;
    size_t second_loop_id = 0;
    try {
        const auto& subgraph = desc.at("subgraph");
        first_map_id = subgraph.at("0").at("element_id").get<size_t>();
        second_loop_id = subgraph.at("1").at("element_id").get<size_t>();
    } catch (const nlohmann::json::out_of_range& e) {
        throw InvalidTransformationDescriptionException(
            std::string("Malformed MapFusion description: ") + e.what()
        );
    }

    auto first_element = builder.find_element_by_id(first_map_id);
    auto second_element = builder.find_element_by_id(second_loop_id);

    if (first_element == nullptr) {
        throw InvalidTransformationDescriptionException("Element with ID " + std::to_string(first_map_id) + " not found.");
    }
    if (second_element == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(second_loop_id) + " not found."
        );
    }

    // The producer must be a Map; the consumer may be any StructuredLoop (Map or For).
    auto* first_map = dynamic_cast<structured_control_flow::Map*>(first_element);
    auto* second_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(second_element);

    if (first_map == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(first_map_id) + " is not a Map."
        );
    }
    if (second_loop == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(second_loop_id) + " is not a StructuredLoop."
        );
    }

    return MapFusion(*first_map, *second_loop);
}
1✔
1355

1356
} // namespace transformations
1357
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc