• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 28106147644

24 Jun 2026 02:32PM UTC coverage: 61.922% (+0.1%) from 61.779%
28106147644

Pull #806

github

web-flow
Merge 2be414d54 into 57cc1db99
Pull Request #806: Map Collapse for Multiple targets in a nested sequence

165 of 185 new or added lines in 2 files covered. (89.19%)

419 existing lines in 30 files now uncovered.

37705 of 60891 relevant lines covered (61.92%)

1004.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.15
/opt/src/transformations/map_fusion.cpp
1
#include "sdfg/transformations/map_fusion.h"
2

3
#include <isl/ctx.h>
4
#include <isl/map.h>
5
#include <isl/options.h>
6
#include <isl/set.h>
7
#include <isl/space.h>
8
#include <symengine/solve.h>
9
#include "sdfg/analysis/arguments_analysis.h"
10
#include "sdfg/analysis/loop_analysis.h"
11

12
#include "sdfg/analysis/assumptions_analysis.h"
13
#include "sdfg/control_flow/interstate_edge.h"
14
#include "sdfg/data_flow/data_flow_graph.h"
15
#include "sdfg/structured_control_flow/block.h"
16
#include "sdfg/structured_control_flow/for.h"
17
#include "sdfg/symbolic/delinearization.h"
18
#include "sdfg/symbolic/utils.h"
19
#include "sdfg/visitor/structured_sdfg_visitor.h"
20

21
namespace sdfg {
22
namespace transformations {
23

24
class FusionConsumerSubsetVisitor : public visitor::ActualStructuredSDFGVisitor {
25
    friend MapFusion;
26

27
    std::unordered_map<std::string, const data_flow::Subset*>& target_containers_;
28
    std::unordered_map<std::string, std::vector<data_flow::Subset>> unique_subsets_per_container_;
29

30
protected:
31
    bool abort() { return true; }
1✔
32

33
public:
34
    FusionConsumerSubsetVisitor(std::unordered_map<std::string, const data_flow::Subset*>& target_containers)
35
        : target_containers_(target_containers) {}
35✔
36

37
    bool visit(sdfg::structured_control_flow::Block& block) override {
36✔
38
        auto& dataflow = block.dataflow();
36✔
39
        for (auto& node : dataflow.nodes()) {
127✔
40
            auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
127✔
41
            if (access == nullptr) {
127✔
42
                continue;
38✔
43
            }
38✔
44
            auto& container = access->data();
89✔
45

46
            auto target_it = target_containers_.find(container);
89✔
47
            if (target_it == target_containers_.end()) {
89✔
48
                continue;
47✔
49
            }
47✔
50
            auto& producer_subset = *target_it->second;
42✔
51
            auto& unique_subsets = unique_subsets_per_container_[container]; // Ensures entry exists
42✔
52

53
            // Skip write-only access nodes (consumer also writes the fusion container)
54
            if (dataflow.in_degree(*access) > 0 && dataflow.out_degree(*access) == 0) {
42✔
55
                continue;
3✔
56
            }
3✔
57
            if (dataflow.in_degree(*access) != 0 || dataflow.out_degree(*access) == 0) {
39✔
58
                return abort();
×
59
            }
×
60

61
            // Check all read memlets from this access
62
            for (auto& memlet : dataflow.out_edges(*access)) {
41✔
63
                if (memlet.type() != data_flow::MemletType::Computational) {
41✔
64
                    return abort();
×
65
                }
×
66

67
                auto& consumer_subset = memlet.subset();
41✔
68
                if (consumer_subset.size() != producer_subset.size()) {
41✔
69
                    return abort();
1✔
70
                }
1✔
71

72
                // Check if this subset is already in unique_subsets
73
                bool found = false;
40✔
74
                for (const auto& existing : unique_subsets) {
40✔
75
                    if (existing.size() != consumer_subset.size()) continue;
7✔
76
                    bool match = true;
7✔
77
                    for (size_t d = 0; d < existing.size(); ++d) {
9✔
78
                        if (!symbolic::eq(existing[d], consumer_subset[d])) {
7✔
79
                            match = false;
5✔
80
                            break;
5✔
81
                        }
5✔
82
                    }
7✔
83
                    if (match) {
7✔
84
                        found = true;
2✔
85
                        break;
2✔
86
                    }
2✔
87
                }
7✔
88
                if (!found) {
40✔
89
                    unique_subsets.push_back(consumer_subset);
38✔
90
                }
38✔
91
            }
40✔
92
        }
39✔
93
        return false;
35✔
94
    }
36✔
95

96
    bool visit(sdfg::structured_control_flow::Sequence& node) override {
35✔
97
        for (int i = 0; i < node.size(); ++i) {
70✔
98
            if (dispatch(node.at(i).first)) {
36✔
99
                return true;
1✔
100
            }
1✔
101
        }
36✔
102

103
        return false;
34✔
104
    }
35✔
105

106
    bool visit(IfElse& node) override {
×
107
        for (int i = 0; i < node.size(); ++i) {
×
108
            if (visit(node.at(i).first)) {
×
109
                return true;
×
110
            }
×
111
        }
×
112

113
        return false;
×
114
    }
×
115
};
116

117
class FusionConsumerUpdateVisitor : public visitor::ActualStructuredSDFGVisitor {
118
    friend MapFusion;
119

120
    builder::StructuredSDFGBuilder& builder_;
121
    const std::vector<MapFusion::FusionCandidate>& fusion_candidates_;
122
    const std::vector<std::string>& candidate_temps_;
123

124
public:
125
    FusionConsumerUpdateVisitor(
126
        builder::StructuredSDFGBuilder& builder,
127
        const std::vector<MapFusion::FusionCandidate>& fusion_candidates,
128
        const std::vector<std::string>& candidate_temps
129
    )
130
        : builder_(builder), fusion_candidates_(fusion_candidates), candidate_temps_(candidate_temps) {}
17✔
131

132
    bool dispatch_partial_sequence(Sequence& node, size_t first, size_t end) {
17✔
133
        for (int i = first; i < end; ++i) {
36✔
134
            if (dispatch(node.at(i).first)) {
19✔
135
                return true;
×
136
            }
×
137
        }
19✔
138

139
        return false;
17✔
140
    }
17✔
141

142
    bool visit(sdfg::structured_control_flow::Block& block) override {
19✔
143
        auto& dataflow = block.dataflow();
19✔
144

145
        // Snapshot access nodes before mutation: adding new access nodes below
146
        // would rehash dataflow.nodes_ and invalidate the range iterator.
147
        std::vector<data_flow::AccessNode*> access_nodes;
19✔
148
        for (auto& node : dataflow.nodes()) {
68✔
149
            auto* an = dynamic_cast<data_flow::AccessNode*>(&node);
68✔
150
            if (an != nullptr && dataflow.out_degree(*an) > 0) {
68✔
151
                access_nodes.push_back(an);
28✔
152
            }
28✔
153
        }
68✔
154

155
        for (auto* access : access_nodes) {
28✔
156
            std::string original_container = access->data();
28✔
157

158
            // Match each out-edge against a fusion candidate.
159
            struct Match {
28✔
160
                data_flow::Memlet* memlet;
28✔
161
                size_t cand_idx;
28✔
162
            };
28✔
163
            std::vector<Match> matches;
28✔
164
            for (auto& memlet : dataflow.out_edges(*access)) {
30✔
165
                if (memlet.type() != data_flow::MemletType::Computational) {
30✔
166
                    continue;
×
167
                }
×
168
                const auto& memlet_subset = memlet.subset();
30✔
169
                for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
41✔
170
                    auto& candidate = fusion_candidates_[cand_idx];
32✔
171
                    if (original_container != candidate.container) {
32✔
172
                        continue;
9✔
173
                    }
9✔
174
                    if (memlet_subset.size() != candidate.consumer_subset.size()) {
23✔
175
                        continue;
×
176
                    }
×
177
                    bool subset_matches = true;
23✔
178
                    for (size_t d = 0; d < memlet_subset.size(); ++d) {
47✔
179
                        if (!symbolic::eq(memlet_subset[d], candidate.consumer_subset[d])) {
26✔
180
                            subset_matches = false;
2✔
181
                            break;
2✔
182
                        }
2✔
183
                    }
26✔
184
                    if (subset_matches) {
23✔
185
                        matches.push_back({&memlet, cand_idx});
21✔
186
                        break;
21✔
187
                    }
21✔
188
                }
23✔
189
            }
30✔
190
            if (matches.empty()) {
28✔
191
                continue;
9✔
192
            }
9✔
193

194
            // Group matches by candidate index.
195
            std::unordered_set<size_t> distinct_cands;
19✔
196
            for (auto& m : matches) {
21✔
197
                distinct_cands.insert(m.cand_idx);
21✔
198
            }
21✔
199

200
            if (distinct_cands.size() == 1) {
19✔
201
                // Fast path: all matched out-edges resolve to the same candidate.
202
                // Mutate the shared access node in place — this preserves the
203
                // existing semantics for the single-read-per-container case.
204
                size_t cand_idx = *distinct_cands.begin();
18✔
205
                const auto& temp_name = candidate_temps_[cand_idx];
18✔
206
                auto& temp_type = builder_.subject().type(temp_name);
18✔
207

208
                access->data(temp_name);
18✔
209

210
                for (auto& m : matches) {
19✔
211
                    m.memlet->set_subset({});
19✔
212
                    m.memlet->set_base_type(temp_type);
19✔
213
                }
19✔
214

215
                for (auto& in_edge : dataflow.in_edges(*access)) {
18✔
216
                    in_edge.set_subset({});
×
217
                    in_edge.set_base_type(temp_type);
×
218
                }
×
219
            } else {
18✔
220
                // Stencil-like case: a single access node feeds reads at
221
                // multiple distinct subsets (e.g. T[j-1] and T[j+1] sharing
222
                // one AccessNode). Each must be rewired to its own
223
                // candidate-specific temp scalar — otherwise mutating
224
                // `access->data()` once per candidate makes all reads
225
                // collapse onto the last temp, e.g. T[j+1]-T[j] becomes
226
                // tmp-tmp == 0.
227
                //
228
                // Fix: for each distinct candidate, create one fresh
229
                // AccessNode for its temp scalar and redirect the matched
230
                // edges from the shared access node to the fresh nodes.
231
                struct PendingRedirect {
1✔
232
                    data_flow::DataFlowNode* dst;
1✔
233
                    std::string src_conn;
1✔
234
                    std::string dst_conn;
1✔
235
                    DebugInfo debug_info;
1✔
236
                    size_t cand_idx;
1✔
237
                    const data_flow::Memlet* memlet_to_remove;
1✔
238
                };
1✔
239
                std::vector<PendingRedirect> pending;
1✔
240
                pending.reserve(matches.size());
1✔
241
                for (auto& m : matches) {
2✔
242
                    pending.push_back(
2✔
243
                        {&m.memlet->dst(),
2✔
244
                         m.memlet->src_conn(),
2✔
245
                         m.memlet->dst_conn(),
2✔
246
                         m.memlet->debug_info(),
2✔
247
                         m.cand_idx,
2✔
248
                         m.memlet}
2✔
249
                    );
2✔
250
                }
2✔
251

252
                std::unordered_map<size_t, data_flow::AccessNode*> per_cand_node;
1✔
253
                for (auto& p : pending) {
2✔
254
                    auto it = per_cand_node.find(p.cand_idx);
2✔
255
                    if (it == per_cand_node.end()) {
2✔
256
                        auto& fresh = builder_.add_access(block, candidate_temps_[p.cand_idx]);
2✔
257
                        it = per_cand_node.emplace(p.cand_idx, &fresh).first;
2✔
258
                    }
2✔
259
                    auto& temp_type = builder_.subject().type(candidate_temps_[p.cand_idx]);
2✔
260
                    builder_.remove_memlet(block, *p.memlet_to_remove);
2✔
261
                    builder_.add_memlet(block, *it->second, p.src_conn, *p.dst, p.dst_conn, {}, temp_type, p.debug_info);
2✔
262
                }
2✔
263

264
                // If the original shared access node now has no edges at all
265
                // it is dangling and should be removed. Keep it if it still
266
                // has out-edges (unmatched reads of the original container)
267
                // or in-edges (writes to the original container).
268
                if (dataflow.out_degree(*access) == 0 && dataflow.in_degree(*access) == 0) {
1✔
269
                    builder_.remove_node(block, *access);
1✔
270
                }
1✔
271
            }
1✔
272
        }
19✔
273
        return false;
19✔
274
    }
19✔
275

276
    bool visit(sdfg::structured_control_flow::Sequence& node) override {
×
277
        for (int i = 0; i < node.size(); ++i) {
×
278
            if (dispatch(node.at(i).first)) {
×
279
                return true;
×
280
            }
×
281
        }
×
282

283
        return false;
×
284
    }
×
285

286
    bool visit(IfElse& node) override {
×
287
        for (int i = 0; i < node.size(); ++i) {
×
288
            if (visit(node.at(i).first)) {
×
289
                return true;
×
290
            }
×
291
        }
×
292

293
        return false;
×
294
    }
×
295
};
296

297
MapFusion::MapFusion(
298
    structured_control_flow::Map& first_map,
299
    structured_control_flow::StructuredLoop& second_loop,
300
    bool require_consecutive,
301
    bool allow_init_hoist
302
)
303
    : first_map_(first_map), second_loop_(second_loop), require_consecutive_(require_consecutive),
43✔
304
      allow_init_hoist_(allow_init_hoist) {}
43✔
305

306
/// Identifier of this transformation.
std::string MapFusion::name() const {
    return std::string("MapFusion");
}
307

308
std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> MapFusion::solve_subsets(
309
    const data_flow::Subset& producer_subset,
310
    const data_flow::Subset& consumer_subset,
311
    const std::vector<structured_control_flow::StructuredLoop*>& producer_loops,
312
    const std::vector<structured_control_flow::StructuredLoop*>& consumer_loops,
313
    const symbolic::Assumptions& producer_assumptions,
314
    const symbolic::Assumptions& consumer_assumptions,
315
    bool invert_range_check
316
) {
38✔
317
    // Delinearize subsets to recover multi-dimensional structure from linearized accesses
318
    // e.g. T[i*N + j] with assumptions on bounds -> T[i, j]
319
    auto producer_sub = producer_subset;
38✔
320
    if (producer_sub.size() == 1) {
38✔
321
        auto producer_result = symbolic::delinearize(producer_sub.at(0), producer_assumptions);
29✔
322
        if (producer_result.success) {
29✔
323
            producer_sub = producer_result.indices;
26✔
324
        }
26✔
325
    }
29✔
326
    auto consumer_sub = consumer_subset;
38✔
327
    if (consumer_sub.size() == 1) {
38✔
328
        auto consumer_result = symbolic::delinearize(consumer_sub.at(0), consumer_assumptions);
29✔
329
        if (consumer_result.success) {
29✔
330
            consumer_sub = consumer_result.indices;
26✔
331
        }
26✔
332
    }
29✔
333

334
    // Subset dimensions must match
335
    if (producer_sub.size() != consumer_sub.size()) {
38✔
336
        return {};
×
337
    }
×
338
    if (producer_sub.empty()) {
38✔
339
        return {};
×
340
    }
×
341

342
    // Extract producer indvars
343
    SymEngine::vec_sym producer_vars;
38✔
344
    for (auto* loop : producer_loops) {
51✔
345
        producer_vars.push_back(SymEngine::rcp_static_cast<const SymEngine::Symbol>(loop->indvar()));
51✔
346
    }
51✔
347

348
    // Step 1: Solve the linear equation system using SymEngine
349
    // System: producer_sub[d] - consumer_sub[d] = 0, for each dimension d
350
    // Solve for producer_vars in terms of consumer_vars and parameters
351
    SymEngine::vec_basic equations;
38✔
352
    for (size_t d = 0; d < producer_sub.size(); ++d) {
89✔
353
        equations.push_back(symbolic::sub(producer_sub.at(d), consumer_sub.at(d)));
51✔
354
    }
51✔
355

356
    // Need exactly as many equations as unknowns for a unique solution.
357
    // Underdetermined systems (e.g. linearized access with multiple loop vars)
358
    // cannot be uniquely solved and would crash linsolve.
359
    if (equations.size() != producer_vars.size()) {
38✔
360
        return {};
×
361
    }
×
362

363
    SymEngine::vec_basic solution;
38✔
364
    try {
38✔
365
        solution = SymEngine::linsolve(equations, producer_vars);
38✔
366
    } catch (...) {
38✔
367
        return {};
×
368
    }
×
369
    if (solution.size() != producer_vars.size()) {
38✔
370
        return {};
×
371
    }
×
372
    // Build consumer var set for atom validation
373
    symbolic::SymbolSet consumer_var_set;
38✔
374
    for (auto* loop : consumer_loops) {
53✔
375
        consumer_var_set.insert(loop->indvar());
53✔
376
    }
53✔
377

378
    std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> mappings;
38✔
379
    for (size_t i = 0; i < producer_vars.size(); ++i) {
89✔
380
        auto& sol = solution[i];
51✔
381

382
        // Check for invalid solutions
383
        if (SymEngine::is_a<SymEngine::NaN>(*sol) || SymEngine::is_a<SymEngine::Infty>(*sol)) {
51✔
384
            return {};
×
385
        }
×
386

387
        // Validate that solution atoms are consumer vars or parameters
388
        for (const auto& atom : symbolic::atoms(sol)) {
53✔
389
            if (consumer_var_set.count(atom)) {
53✔
390
                continue;
53✔
391
            }
53✔
392
            bool is_param = false;
×
393
            auto it = consumer_assumptions.find(atom);
×
394
            if (it != consumer_assumptions.end() && it->second.constant()) {
×
395
                is_param = true;
×
396
            }
×
397
            if (!is_param) {
×
398
                it = producer_assumptions.find(atom);
×
399
                if (it != producer_assumptions.end() && it->second.constant()) {
×
400
                    is_param = true;
×
401
                }
×
402
            }
×
403
            if (!is_param) {
×
404
                return {};
×
405
            }
×
406
        }
×
407

408
        mappings.push_back({symbolic::symbol(producer_vars[i]->get_name()), symbolic::expand(sol)});
51✔
409
    }
51✔
410
    // Step 2: ISL integrality validation via map composition
411
    // Build an unconstrained producer access map (no domain bounds on producer vars).
412
    // In map fusion, the producer's computation is inlined into the consumer, so
413
    // the producer's original iteration domain is irrelevant. We only need to verify
414
    // that the equation system has an INTEGER solution for every consumer point.
415
    symbolic::Assumptions unconstrained_producer;
38✔
416
    for (auto* loop : producer_loops) {
51✔
417
        symbolic::Assumption a(loop->indvar());
51✔
418
        a.constant(false);
51✔
419
        unconstrained_producer[loop->indvar()] = a;
51✔
420
    }
51✔
421
    for (const auto& [sym, assump] : producer_assumptions) {
157✔
422
        if (assump.constant() && unconstrained_producer.find(sym) == unconstrained_producer.end()) {
157✔
423
            unconstrained_producer[sym] = assump;
51✔
424
        }
51✔
425
    }
157✔
426

427
    std::string producer_map_str = symbolic::expression_to_map_str(producer_sub, unconstrained_producer);
38✔
428
    // Build consumer access map with full domain constraints
429
    std::string consumer_map_str = symbolic::expression_to_map_str(consumer_sub, consumer_assumptions);
38✔
430

431
    isl_ctx* ctx = isl_ctx_alloc();
38✔
432
    isl_options_set_on_error(ctx, ISL_ON_ERROR_CONTINUE);
38✔
433

434
    isl_map* producer_map = isl_map_read_from_str(ctx, producer_map_str.c_str());
38✔
435
    isl_map* consumer_map = isl_map_read_from_str(ctx, consumer_map_str.c_str());
38✔
436

437
    if (!producer_map || !consumer_map) {
38✔
438
        if (producer_map) isl_map_free(producer_map);
×
439
        if (consumer_map) isl_map_free(consumer_map);
×
440
        isl_ctx_free(ctx);
×
441
        return {};
×
442
    }
×
443

444
    // Align parameters between the two maps
445
    isl_space* params_p = isl_space_params(isl_map_get_space(producer_map));
38✔
446
    isl_space* params_c = isl_space_params(isl_map_get_space(consumer_map));
38✔
447
    isl_space* unified = isl_space_align_params(isl_space_copy(params_p), isl_space_copy(params_c));
38✔
448
    isl_space_free(params_p);
38✔
449
    isl_space_free(params_c);
38✔
450

451
    producer_map = isl_map_align_params(producer_map, isl_space_copy(unified));
38✔
452
    consumer_map = isl_map_align_params(consumer_map, isl_space_copy(unified));
38✔
453

454
    // Save consumer domain before consuming consumer_map in composition
455
    isl_set* consumer_domain = isl_map_domain(isl_map_copy(consumer_map));
38✔
456

457
    // Compute composition: consumer_access ∘ inverse(producer_access)
458
    // This checks whether the equation system producer_subset = consumer_subset
459
    // has an integer solution for each consumer domain point.
460
    isl_map* producer_inverse = isl_map_reverse(producer_map);
38✔
461
    isl_map* composition = isl_map_apply_range(consumer_map, producer_inverse);
38✔
462

463
    // Check single-valuedness: each consumer point maps to at most one producer point
464
    bool single_valued = isl_map_is_single_valued(composition) == isl_bool_true;
38✔
465

466
    // Check domain coverage: every consumer point has a valid integer mapping
467
    isl_set* comp_domain = isl_map_domain(composition);
38✔
468

469
    bool domain_covered = isl_set_is_subset(consumer_domain, comp_domain) == isl_bool_true;
38✔
470

471
    isl_set_free(comp_domain);
38✔
472
    isl_set_free(consumer_domain);
38✔
473

474
    // Step 3: Verify producer write range covers consumer read range.
475
    // The producer only writes a subset of the array if its loops have restricted bounds.
476
    // Fusion is invalid if the consumer reads elements the producer never writes.
477
    bool range_covered = false;
38✔
478
    if (single_valued && domain_covered) {
38✔
479
        std::string constrained_producer_map_str = symbolic::expression_to_map_str(producer_sub, producer_assumptions);
36✔
480
        isl_map* constrained_producer = isl_map_read_from_str(ctx, constrained_producer_map_str.c_str());
36✔
481
        isl_map* consumer_map_copy = isl_map_read_from_str(ctx, consumer_map_str.c_str());
36✔
482

483
        if (constrained_producer && consumer_map_copy) {
36✔
484
            constrained_producer = isl_map_align_params(constrained_producer, isl_space_copy(unified));
36✔
485
            consumer_map_copy = isl_map_align_params(consumer_map_copy, isl_space_copy(unified));
36✔
486

487
            isl_set* producer_range = isl_map_range(constrained_producer);
36✔
488
            isl_set* consumer_range = isl_map_range(consumer_map_copy);
36✔
489

490
            // When arguments are swapped (ConsumerIntoProducer), the "producer"/"consumer"
491
            // labels are inverted. Flip the subset check to always verify:
492
            // actual_consumer_read_range ⊆ actual_producer_write_range
493
            if (invert_range_check) {
36✔
494
                range_covered = isl_set_is_subset(producer_range, consumer_range) == isl_bool_true;
4✔
495
            } else {
32✔
496
                range_covered = isl_set_is_subset(consumer_range, producer_range) == isl_bool_true;
32✔
497
            }
32✔
498

499
            isl_set_free(producer_range);
36✔
500
            isl_set_free(consumer_range);
36✔
501
        } else {
36✔
502
            if (constrained_producer) isl_map_free(constrained_producer);
×
503
            if (consumer_map_copy) isl_map_free(consumer_map_copy);
×
504
        }
×
505
    }
36✔
506

507
    isl_space_free(unified);
38✔
508
    isl_ctx_free(ctx);
38✔
509

510
    if (!single_valued || !domain_covered || !range_covered) {
38✔
511
        return {};
5✔
512
    }
5✔
513

514
    return mappings;
33✔
515
}
38✔
516

517
bool MapFusion::find_write_location(
518
    structured_control_flow::StructuredLoop& loop,
519
    const std::string& container,
520
    std::vector<structured_control_flow::StructuredLoop*>& loops,
521
    structured_control_flow::Sequence*& body,
522
    structured_control_flow::Block*& block
523
) {
4✔
524
    loops.push_back(&loop);
4✔
525
    auto& seq = loop.root();
4✔
526

527
    for (size_t i = 0; i < seq.size(); ++i) {
10✔
528
        auto& child = seq.at(i).first;
6✔
529

530
        if (auto* blk = dynamic_cast<structured_control_flow::Block*>(&child)) {
6✔
531
            // Check if this block writes to the container
532
            auto& dataflow = blk->dataflow();
4✔
533
            for (auto& node : dataflow.nodes()) {
14✔
534
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
14✔
535
                if (access == nullptr || access->data() != container) {
14✔
536
                    continue;
12✔
537
                }
12✔
538
                // Write access: has incoming edges (sink node)
539
                if (dataflow.in_degree(*access) > 0 && dataflow.out_degree(*access) == 0) {
2✔
540
                    if (block != nullptr) {
2✔
541
                        // Multiple write blocks found — ambiguous
542
                        return false;
×
543
                    }
×
544
                    body = &seq;
2✔
545
                    block = blk;
2✔
546
                }
2✔
547
            }
2✔
548
        } else if (auto* nested_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&child)) {
4✔
549
            if (!find_write_location(*nested_loop, container, loops, body, block)) {
2✔
550
                return false;
×
551
            }
×
552
            // If we didn't find the write in this subtree, pop the loop back off
553
            if (loops.back() != &loop && block == nullptr) {
2✔
554
                // The recursive call already popped — but we need to check
555
            }
×
556
        }
2✔
557
    }
6✔
558

559
    // If we didn't find the write in this subtree, remove this loop from the chain
560
    if (block == nullptr) {
4✔
561
        loops.pop_back();
×
562
    }
×
563

564
    return true;
4✔
565
}
4✔
566

567
bool MapFusion::find_read_location(
568
    structured_control_flow::StructuredLoop& loop,
569
    const std::string& container,
570
    std::vector<structured_control_flow::StructuredLoop*>& loops,
571
    structured_control_flow::Sequence*& body
572
) {
2✔
573
    loops.push_back(&loop);
2✔
574
    auto& seq = loop.root();
2✔
575

576
    for (size_t i = 0; i < seq.size(); ++i) {
5✔
577
        auto& child = seq.at(i).first;
3✔
578

579
        if (auto* blk = dynamic_cast<structured_control_flow::Block*>(&child)) {
3✔
580
            // Check if this block reads from the container
581
            auto& dataflow = blk->dataflow();
2✔
582
            for (auto& node : dataflow.nodes()) {
8✔
583
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
8✔
584
                if (access == nullptr || access->data() != container) {
8✔
585
                    continue;
7✔
586
                }
7✔
587
                // Read access: has outgoing edges (source node)
588
                if (dataflow.in_degree(*access) == 0 && dataflow.out_degree(*access) > 0) {
1✔
589
                    if (body != nullptr && body != &seq) {
1✔
590
                        // Reads at different sequence levels — ambiguous
591
                        return false;
×
592
                    }
×
593
                    body = &seq;
1✔
594
                }
1✔
595
            }
1✔
596
        } else if (auto* nested_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&child)) {
2✔
597
            if (!find_read_location(*nested_loop, container, loops, body)) {
1✔
598
                return false;
×
599
            }
×
600
        }
1✔
601
    }
3✔
602

603
    // If we didn't find any reads in this subtree, remove this loop from the chain
604
    if (body == nullptr) {
2✔
605
        loops.pop_back();
×
606
    }
×
607

608
    return true;
2✔
609
}
2✔
610

611
bool MapFusion::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
41✔
612
    fusion_candidates_.clear();
41✔
613

614
    // no use in fusing empty loops. Also presumed to not be empty further down
615
    if (first_map_.root().size() == 0 || second_loop_.root().size() == 0) {
41✔
616
        return false;
×
617
    }
×
618

619
    // Criterion: Get parent scope and verify both loops are sequential children
620
    auto* first_parent = first_map_.get_parent();
41✔
621
    auto* second_parent = second_loop_.get_parent();
41✔
622
    if (first_parent == nullptr || second_parent == nullptr) {
41✔
623
        return false;
×
624
    }
×
625
    if (first_parent != second_parent) {
41✔
626
        return false;
×
627
    }
×
628

629
    auto* parent_sequence = dynamic_cast<structured_control_flow::Sequence*>(first_parent);
41✔
630
    if (parent_sequence == nullptr) {
41✔
631
        return false;
×
632
    }
×
633

634
    int first_index = parent_sequence->index(first_map_);
41✔
635
    int second_index = parent_sequence->index(second_loop_);
41✔
636
    if (first_index == -1 || second_index == -1) {
41✔
637
        return false;
×
638
    }
×
639
    if (require_consecutive_ && second_index != first_index + 1) {
41✔
640
        return false;
1✔
641
    }
1✔
642

643
    // Criterion: Transition between maps should have no assignments
644
    if (require_consecutive_) {
40✔
645
        auto& transition = parent_sequence->at(first_index).second;
40✔
646
        if (!transition.empty()) {
40✔
647
            return false;
×
648
        }
×
649
    }
40✔
650
    // Determine fusion pattern based on nesting properties
651
    auto& loop_analysis = analysis_manager.get<analysis::LoopAnalysis>();
40✔
652
    auto first_loop_info = loop_analysis.loop_info(&first_map_);
40✔
653
    auto second_loop_info = loop_analysis.loop_info(&second_loop_);
40✔
654

655
    auto limit_depth = 0;
40✔
656

657
    bool first_nested = first_loop_info.is_perfectly_nested;
40✔
658
    bool second_nested = second_loop_info.is_perfectly_nested;
40✔
659

660
    // Both non-perfectly-nested: not supported
661
    if (!first_nested && !second_nested) {
40✔
662
        return false;
1✔
663
    }
1✔
664

665
    if (first_nested && second_nested) {
39✔
666
        // Pattern 1: Both perfectly nested — producer into consumer (original path)
667
        direction_ = FusionDirection::ProducerIntoConsumer;
36✔
668
    } else if (!first_nested && second_nested) {
36✔
669
        // Pattern 2: Producer non-perfectly-nested, consumer perfectly nested
670
        direction_ = FusionDirection::ConsumerIntoProducer;
2✔
671
    } else {
2✔
672
        // Reverse Pattern 2: Producer perfectly nested, consumer non-perfectly-nested
673
        direction_ = FusionDirection::ProducerIntoConsumer;
1✔
674
    }
1✔
675

676
    // The side being inlined must be all-parallel (all Maps) so iterations can be reordered.
677
    // ProducerIntoConsumer: the producer is replicated at each consumer site and must be
678
    // reorderable, so it must be all-parallel. The consumer is normally required to be
679
    // all-parallel too, because a sequential (For) loop would re-execute the inlined producer
680
    // on every iteration (e.g. init T=0 fused into For(k){T+=A[k]} re-initializes each k).
681
    //
682
    // Reduction branch: we relax the consumer requirement when the consumer is a perfect nest
683
    // (parallel outer band + inner sequential For, i.e. a reduction). A fully-parallel producer
684
    // that is *streamed element-by-element* inside the reduction loop can still be inlined
685
    // soundly (e.g. scale -> max: max(M, A[i,j,k]/d)). The element-streaming safety conditions
686
    // are verified once the fusion candidates are known (see consumer_reduction_branch below):
687
    //   (1) the fused container must not be written by the consumer (no loop-carried
688
    //       accumulator), and
689
    //   (2) its consumer read subset must depend on an inner sequential loop indvar, so the
690
    //       inlined producer runs once per element rather than per init position.
691
    // These keep init-into-reduction (T=0 followed by For(k){T+=...}) rejected.
692
    // ConsumerIntoProducer: only the consumer (inlined side) must be all-parallel.
693
    bool consumer_reduction_branch = false;
39✔
694
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
39✔
695
        if (!first_loop_info.is_perfectly_parallel) {
37✔
696
            return false;
×
697
        } else if (!second_loop_info.is_perfectly_parallel) {
37✔
698
            if (!second_loop_info.is_perfectly_nested) {
2✔
699
                return false;
×
700
            }
×
701
            consumer_reduction_branch = true;
2✔
702
        }
2✔
703
    } else {
37✔
704
        if (!second_loop_info.is_perfectly_parallel) {
2✔
705
            return false;
×
706
        }
×
707
    }
2✔
708

709
    // Locate producer write point
710
    producer_loops_.clear();
39✔
711
    producer_body_ = nullptr;
39✔
712
    producer_block_ = nullptr;
39✔
713

714
    if (first_nested) {
39✔
715
        // Perfectly nested: walk the at(0).first chain
716
        producer_loops_.push_back(&first_map_);
37✔
717
        producer_body_ = &first_map_.root();
37✔
718
        structured_control_flow::ControlFlowNode* node = &first_map_.root().at(0).first;
37✔
719
        int level = 1;
37✔
720
        while (auto* nested = dynamic_cast<structured_control_flow::StructuredLoop*>(node)) {
49✔
721
            if (limit_depth && ++level > limit_depth) {
13✔
722
                break;
×
723
            }
×
724
            producer_loops_.push_back(nested);
13✔
725
            producer_body_ = &nested->root();
13✔
726
            if (nested->root().size() == 0) return false;
13✔
727
            node = &nested->root().at(0).first;
12✔
728
        }
12✔
729
        producer_block_ = dynamic_cast<structured_control_flow::Block*>(node);
36✔
730
        if (producer_block_ == nullptr) {
36✔
731
            return false;
×
732
        }
×
733
        // If the body has multiple children, the at(0) walk does not guarantee
734
        // we found the correct (or unique) write block. Fall back to deferred
735
        // find_write_location resolution.
736
        if (producer_body_->size() != 1) {
36✔
737
            producer_block_ = nullptr;
×
738
            // Keep producer_loops_ and producer_body_ from the walk — they are
739
            // still valid for the loop chain. find_write_location will re-resolve
740
            // the block within producer_body_.
741
        }
×
742
    } else {
36✔
743
        // Non-perfectly-nested: search recursively for the write block
744
        // We need to know which containers to look for, but we don't know them yet.
745
        // Defer write location search until after fusion_containers are identified.
746
    }
2✔
747

748
    // Locate consumer read point
749
    consumer_loops_.clear();
38✔
750
    consumer_body_ = nullptr;
38✔
751

752
    if (second_nested) {
38✔
753
        // Perfectly nested: walk the at(0).first chain through all loop types.
754
        // Reduction patterns (e.g. Map{Map{For{T[i,j]+=...}}}) are rejected by
755
        // the is_perfectly_parallel check — For loops make it non-parallel.
756
        consumer_loops_.push_back(&second_loop_);
37✔
757
        consumer_body_ = &second_loop_.root();
37✔
758
        structured_control_flow::ControlFlowNode* node = &second_loop_.root().at(0).first;
37✔
759
        int level = 1;
37✔
760
        while (auto* nested = dynamic_cast<structured_control_flow::StructuredLoop*>(node)) {
51✔
761
            if (limit_depth && ++level > limit_depth) {
15✔
762
                break;
×
763
            }
×
764
            consumer_loops_.push_back(nested);
15✔
765
            consumer_body_ = &nested->root();
15✔
766
            if (nested->root().size() == 0) return false;
15✔
767
            node = &nested->root().at(0).first;
14✔
768
        }
14✔
769
    } else {
37✔
770
        // Non-perfectly-nested: defer read location search until after fusion_containers are identified.
771
    }
1✔
772

773
    // Get arguments analysis to identify inputs/outputs of each loop
774
    auto& arguments_analysis = analysis_manager.get<analysis::ArgumentsAnalysis>();
37✔
775
    auto first_args = arguments_analysis.arguments(analysis_manager, first_map_);
37✔
776
    auto second_args = arguments_analysis.arguments(analysis_manager, second_loop_);
37✔
777

778
    std::unordered_set<std::string> first_inputs;
37✔
779
    std::unordered_set<std::string> first_outputs;
37✔
780
    for (const auto& [name, arg] : first_args) {
122✔
781
        if (arg.is_output) {
122✔
782
            first_outputs.insert(name);
38✔
783
        }
38✔
784
        if (arg.is_input) {
122✔
785
            first_inputs.insert(name);
122✔
786
        }
122✔
787
    }
122✔
788

789
    std::unordered_set<std::string> second_outputs;
37✔
790
    for (const auto& [name, arg] : second_args) {
129✔
791
        if (arg.is_output) {
129✔
792
            second_outputs.insert(name);
38✔
793
        }
38✔
794
    }
129✔
795

796
    // First pass: identify fusion containers (producer writes, consumer reads)
797
    std::unordered_set<std::string> fusion_containers;
37✔
798
    for (const auto& [name, arg] : second_args) {
129✔
799
        if (first_outputs.contains(name) && arg.is_input) {
129✔
800
            fusion_containers.insert(name);
36✔
801
        }
36✔
802
    }
129✔
803
    if (fusion_containers.empty()) {
37✔
804
        return false;
1✔
805
    }
1✔
806

807
    // Second pass: check for conflicts on non-fusion containers
808
    for (const auto& [name, arg] : second_args) {
124✔
809
        bool is_fusion = fusion_containers.contains(name);
124✔
810
        if (first_outputs.contains(name) && arg.is_output && !is_fusion) {
124✔
811
            return false;
×
812
        }
×
813
        if (first_inputs.contains(name) && arg.is_output && !is_fusion) {
124✔
814
            return false;
1✔
815
        }
1✔
816
    }
124✔
817

818
    // Now that we know the fusion containers, resolve deferred locations
819
    if (producer_block_ == nullptr) {
35✔
820
        // Non-perfectly-nested producer (or perfectly-nested with multi-block body):
821
        // find write location for the first fusion container.
822
        // All fusion containers must be written at the same block for this to work.
823
        for (const auto& container : fusion_containers) {
2✔
824
            std::vector<structured_control_flow::StructuredLoop*> write_loops;
2✔
825
            structured_control_flow::Sequence* write_body = nullptr;
2✔
826
            structured_control_flow::Block* write_block = nullptr;
2✔
827

828
            if (!find_write_location(first_map_, container, write_loops, write_body, write_block)) {
2✔
829
                return false;
×
830
            }
×
831
            if (write_block == nullptr) {
2✔
832
                return false;
×
833
            }
×
834

835
            if (producer_block_ == nullptr) {
2✔
836
                // First container: set the locations
837
                producer_loops_ = write_loops;
2✔
838
                producer_body_ = write_body;
2✔
839
                producer_block_ = write_block;
2✔
840
            } else {
2✔
841
                // Subsequent containers must be in the same block
842
                if (write_block != producer_block_) {
×
843
                    return false;
×
844
                }
×
845
            }
×
846
        }
2✔
847
    }
2✔
848

849
    if (!second_nested) {
35✔
850
        // Non-perfectly-nested consumer: find read location for the first fusion container
851
        // All fusion containers must be read at the same sequence for this to work
852
        for (const auto& container : fusion_containers) {
1✔
853
            std::vector<structured_control_flow::StructuredLoop*> read_loops;
1✔
854
            structured_control_flow::Sequence* read_body = nullptr;
1✔
855

856
            if (!find_read_location(second_loop_, container, read_loops, read_body)) {
1✔
857
                return false;
×
858
            }
×
859
            if (read_body == nullptr) {
1✔
860
                return false;
×
861
            }
×
862

863
            if (consumer_body_ == nullptr) {
1✔
864
                // First container: set the locations
865
                consumer_loops_ = read_loops;
1✔
866
                consumer_body_ = read_body;
1✔
867
            } else {
1✔
868
                // Subsequent containers must be at the same sequence
869
                if (read_body != consumer_body_) {
×
870
                    return false;
×
871
                }
×
872
            }
×
873
        }
1✔
874
    }
1✔
875

876
    // Get assumptions for the resolved write/read locations
877
    // Include trivial bounds from types to help delinearization with symbolic strides
878
    auto& assumptions_analysis = analysis_manager.get<analysis::AssumptionsAnalysis>();
35✔
879
    auto& producer_assumptions = assumptions_analysis.get(*producer_block_, true);
35✔
880
    auto& consumer_assumptions = assumptions_analysis.get(consumer_body_->at(0).first, true);
35✔
881

882
    // Check if producer actually reads a fusion container in the dataflow.
883
    // If so, ProducerIntoConsumer is unsafe (original producer loop mutates the array
884
    // before the inlined copy reads it). Force ConsumerIntoProducer.
885
    // We check the dataflow directly rather than ArgumentsAnalysis, because the latter
886
    // conservatively marks written containers as also read.
887
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
35✔
888
        auto& first_dataflow_check = producer_block_->dataflow();
33✔
889
        bool producer_reads_fusion = false;
33✔
890
        for (const auto& container : fusion_containers) {
33✔
891
            for (auto& node : first_dataflow_check.nodes()) {
106✔
892
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
106✔
893
                if (access != nullptr && access->data() == container && first_dataflow_check.out_degree(*access) > 0) {
106✔
894
                    producer_reads_fusion = true;
2✔
895
                    break;
2✔
896
                }
2✔
897
            }
106✔
898
            if (producer_reads_fusion) break;
33✔
899
        }
33✔
900
        if (producer_reads_fusion) {
33✔
901
            direction_ = FusionDirection::ConsumerIntoProducer;
2✔
902
            // Re-check: consumer must be all-parallel for ConsumerIntoProducer
903
            if (!second_loop_info.is_perfectly_parallel) {
2✔
904
                return false;
×
905
            }
×
906
        }
2✔
907
    }
33✔
908

909
    // ProducerIntoConsumer only deep-copies producer_block_ into the consumer body.
910
    // If the producer body has multiple blocks (e.g. from prior BlockFusion merging
911
    // a previous fusion's writeback + inlined blocks), the write block may depend on
912
    // intermediates produced by earlier blocks that would NOT be copied. Reject.
913
    if (direction_ == FusionDirection::ProducerIntoConsumer && producer_body_->size() > 1) {
35✔
914
        return false;
×
915
    }
×
916

917
    std::unordered_map<std::string, const data_flow::Subset*> producer_subsets;
35✔
918

919
    // For each fusion container, find the producer memlet and collect unique consumer subsets
920
    auto& first_dataflow = producer_block_->dataflow();
35✔
921
    for (const auto& container : fusion_containers) {
35✔
922
        // Find unique producer write in first map
923
        data_flow::Memlet* producer_memlet = nullptr;
35✔
924

925
        for (auto& node : first_dataflow.nodes()) {
113✔
926
            auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
113✔
927
            if (access == nullptr || access->data() != container) {
113✔
928
                continue;
76✔
929
            }
76✔
930
            // Skip read-only access nodes (producer reads the fusion container)
931
            if (first_dataflow.in_degree(*access) == 0) {
37✔
932
                continue;
2✔
933
            }
2✔
934
            // Write access: must have exactly one incoming edge and no outgoing
935
            if (first_dataflow.in_degree(*access) != 1 || first_dataflow.out_degree(*access) != 0) {
35✔
936
                return false;
×
937
            }
×
938
            auto& iedge = *first_dataflow.in_edges(*access).begin();
35✔
939
            if (iedge.type() != data_flow::MemletType::Computational) {
35✔
940
                return false;
×
941
            }
×
942
            if (producer_memlet != nullptr) {
35✔
943
                return false;
×
944
            }
×
945
            producer_memlet = &iedge;
35✔
946
        }
35✔
947
        if (producer_memlet == nullptr) {
35✔
948
            return false;
×
949
        }
×
950

951
        const auto& producer_subset = producer_memlet->subset();
35✔
952
        if (producer_subset.empty()) {
35✔
953
            return false;
×
954
        } else {
35✔
955
            producer_subsets.emplace(container, &producer_subset);
35✔
956
        }
35✔
957
    }
35✔
958

959
    FusionConsumerSubsetVisitor consumer_visitor(producer_subsets);
35✔
960
    bool abort = consumer_visitor.dispatch(*consumer_body_);
35✔
961
    if (abort) {
35✔
962
        return false;
1✔
963
    }
1✔
964

965
    for (auto [container, unique_subsets] : consumer_visitor.unique_subsets_per_container_) {
34✔
966
        auto& producer_subset = *producer_subsets.at(container);
34✔
967
        // For each unique consumer subset, solve index mappings and create a FusionCandidate
968
        // The direction determines which side's indvars are solved for
969
        for (const auto& consumer_subset : unique_subsets) {
38✔
970
            std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> mappings;
38✔
971

972
            if (direction_ == FusionDirection::ProducerIntoConsumer) {
38✔
973
                // Solve producer indvars in terms of consumer indvars
974
                mappings = solve_subsets(
34✔
975
                    producer_subset,
34✔
976
                    consumer_subset,
34✔
977
                    producer_loops_,
34✔
978
                    consumer_loops_,
34✔
979
                    producer_assumptions,
34✔
980
                    consumer_assumptions
34✔
981
                );
34✔
982
            } else {
34✔
983
                // ConsumerIntoProducer: solve consumer indvars in terms of producer indvars
984
                // Arguments are swapped, so invert the range check direction
985
                mappings = solve_subsets(
4✔
986
                    consumer_subset,
4✔
987
                    producer_subset,
4✔
988
                    consumer_loops_,
4✔
989
                    producer_loops_,
4✔
990
                    consumer_assumptions,
4✔
991
                    producer_assumptions,
4✔
992
                    true
4✔
993
                );
4✔
994
            }
4✔
995

996
            if (mappings.empty()) {
38✔
997
                return false;
5✔
998
            }
5✔
999

1000
            FusionCandidate candidate;
33✔
1001
            candidate.container = container;
33✔
1002
            candidate.consumer_subset = consumer_subset;
33✔
1003
            candidate.index_mappings = std::move(mappings);
33✔
1004

1005
            fusion_candidates_.push_back(candidate);
33✔
1006
        }
33✔
1007
    }
34✔
1008

1009
    // Reduction-branch safety: when fusing a parallel producer into a non-parallel
1010
    // (reduction) consumer, classify each fusion container into one of two sound patterns:
1011
    //   Case 1 (stream):     the container is NOT a consumer output and its consumer read
1012
    //                        depends on an inner sequential indvar -> it is produced and
1013
    //                        consumed element-by-element, so the producer is scalarized and
1014
    //                        inlined inside the reduction loop (e.g. softmax scale -> max).
1015
    //   Case 2 (init-hoist): the container IS a consumer output (the reduction accumulator)
1016
    //                        and its consumer read is loop-invariant w.r.t. every sequential
1017
    //                        indvar -> the producer is the accumulator's initial value and is
1018
    //                        hoisted to the reduction's outer parallel band, before the inner
1019
    //                        sequential loop (e.g. T = -INF preceding T = max(T, x)).
1020
    // Anything else (e.g. an accumulator whose read depends on the sequential indvar, or a
1021
    // streamed value that the consumer also writes) is unsafe and rejected. The two patterns
1022
    // require different placement in apply(), so all candidates must share one pattern.
1023
    if (consumer_reduction_branch) {
29✔
1024
        symbolic::SymbolSet sequential_indvars;
2✔
1025
        size_t first_sequential = consumer_loops_.size();
2✔
1026
        for (size_t li = 0; li < consumer_loops_.size(); ++li) {
8✔
1027
            if (dynamic_cast<structured_control_flow::Map*>(consumer_loops_[li]) == nullptr) {
6✔
1028
                sequential_indvars.insert(consumer_loops_[li]->indvar());
2✔
1029
                if (first_sequential == consumer_loops_.size()) {
2✔
1030
                    first_sequential = li;
2✔
1031
                }
2✔
1032
            }
2✔
1033
        }
6✔
1034
        if (sequential_indvars.empty()) {
2✔
1035
            return false;
×
1036
        }
×
1037
        bool any_stream = false;
2✔
1038
        bool any_init = false;
2✔
1039
        for (const auto& candidate : fusion_candidates_) {
2✔
1040
            bool depends_on_sequential = false;
2✔
1041
            for (const auto& dim : candidate.consumer_subset) {
2✔
1042
                for (const auto& atom : symbolic::atoms(dim)) {
8✔
1043
                    if (sequential_indvars.count(atom)) {
8✔
1044
                        depends_on_sequential = true;
1✔
1045
                        break;
1✔
1046
                    }
1✔
1047
                }
8✔
1048
                if (depends_on_sequential) {
2✔
1049
                    break;
1✔
1050
                }
1✔
1051
            }
2✔
1052

1053
            if (second_outputs.contains(candidate.container)) {
2✔
1054
                // Case 2 candidate: must be a loop-invariant accumulator init.
1055
                if (!allow_init_hoist_) {
1✔
1056
                    // Init-hoisting disabled for this run (reserved for the final
1057
                    // map-fusion pass so it does not fight loop distribution).
1058
                    return false;
×
1059
                }
×
1060
                if (depends_on_sequential) {
1✔
1061
                    return false;
×
1062
                }
×
1063
                any_init = true;
1✔
1064
            } else {
1✔
1065
                // Case 1 candidate: must be a streamed element.
1066
                if (!depends_on_sequential) {
1✔
1067
                    return false;
×
1068
                }
×
1069
                any_stream = true;
1✔
1070
            }
1✔
1071
        }
2✔
1072
        // Do not mix patterns in a single fusion.
1073
        if (any_init && any_stream) {
2✔
1074
            return false;
×
1075
        }
×
1076
        if (any_init) {
2✔
1077
            // Need an enclosing parallel band to host the hoisted init (the init must run
1078
            // once per accumulator element, outside the sequential reduction loop).
1079
            if (first_sequential == 0) {
1✔
1080
                return false;
×
1081
            }
×
1082
            init_hoist_ = true;
1✔
1083
            hoist_body_ = &consumer_loops_[first_sequential - 1]->root();
1✔
1084
        }
1✔
1085
    }
2✔
1086

1087
    // Criterion: At least one valid fusion candidate
1088
    return !fusion_candidates_.empty();
29✔
1089
}
29✔
1090

1091
void MapFusion::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
21✔
1092
    auto& sdfg = builder.subject();
21✔
1093

1094
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
21✔
1095
        // Pattern 1 + Reverse Pattern 2: Inline producer blocks into consumer's read body
1096
        auto& first_dataflow = producer_block_->dataflow();
18✔
1097

1098
        // For each fusion candidate, create a temp and insert a producer block
1099
        std::vector<std::string> candidate_temps;
18✔
1100

1101
        for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
38✔
1102
            auto& candidate = fusion_candidates_[cand_idx];
20✔
1103

1104
            auto& container_type = sdfg.type(candidate.container);
20✔
1105
            types::Scalar tmp_type(container_type.primitive_type());
20✔
1106
            std::string temp_name;
20✔
1107
            if (!init_hoist_) {
20✔
1108
                // Case 1: scalarize the streamed element into a private temp.
1109
                temp_name = builder.find_new_name("_fused_tmp");
19✔
1110
                builder.add_container(temp_name, tmp_type);
19✔
1111
                candidate_temps.push_back(temp_name);
19✔
1112
            }
19✔
1113

1114
            // Insert the producer block at the beginning of the host sequence:
1115
            //  - Case 1 (stream):     consumer_body_ = innermost sequential (reduction) loop body.
1116
            //  - Case 2 (init-hoist): hoist_body_   = outer parallel-band body, before that loop.
1117
            auto& host_seq = init_hoist_ ? *hoist_body_ : *consumer_body_;
20✔
1118
            auto& first_child = host_seq.at(0).first;
20✔
1119
            control_flow::Assignments empty_assignments;
20✔
1120
            auto& new_block = builder.add_block_before(host_seq, first_child, empty_assignments);
20✔
1121
            structured_control_flow::Block* empty_block = nullptr;
20✔
1122

1123
            // Deep copy all nodes from producer block to new block
1124
            std::unordered_map<const data_flow::DataFlowNode*, data_flow::DataFlowNode*> node_mapping;
20✔
1125
            std::unordered_map<std::string, std::string> intermediate_renames;
20✔
1126
            for (auto& node : first_dataflow.nodes()) {
66✔
1127
                node_mapping[&node] = &builder.copy_node(new_block, node);
66✔
1128
                auto* copied = node_mapping[&node];
66✔
1129
                if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(copied)) {
66✔
1130
                    if (!init_hoist_ && access_node->data() == candidate.container) {
46✔
1131
                        // Case 1: redirect the producer's array write to the private scalar.
1132
                        access_node->data(temp_name);
19✔
1133
                    } else if (access_node->data() == first_map_.indvar()->get_name()) {
27✔
1134
                        // Determine the new expression for the index variable of the first map
1135
                        symbolic::Expression new_expr = SymEngine::null;
2✔
1136
                        for (auto& c : fusion_candidates_) {
2✔
1137
                            for (auto& [sym, expr] : c.index_mappings) {
2✔
1138
                                if (symbolic::eq(sym, first_map_.indvar())) {
2✔
1139
                                    new_expr = expr;
2✔
1140
                                    break;
2✔
1141
                                }
2✔
1142
                            }
2✔
1143
                            if (!new_expr.is_null()) {
2✔
1144
                                break;
2✔
1145
                            }
2✔
1146
                        }
2✔
1147

1148
                        if (new_expr.is_null() || symbolic::eq(new_expr, second_loop_.indvar())) {
2✔
1149
                            // Simple case: The new expression is simply the index variable of the second loop
1150
                            access_node->data(second_loop_.indvar()->get_name());
1✔
1151
                        } else {
1✔
1152
                            // Complex case: Add an empty block before the new block (if necessary) and store the
1153
                            // shifted index into a new temporary variable with an assignment. Then, replace the index
1154
                            // variable with the new temporary variable
1155
                            if (!empty_block) {
1✔
1156
                                empty_block = &builder.add_block_before(host_seq, new_block, empty_assignments);
1✔
1157
                            }
1✔
1158
                            auto new_index_name = builder.find_new_name();
1✔
1159
                            builder
1✔
1160
                                .add_container(new_index_name, builder.subject().type(second_loop_.indvar()->get_name()));
1✔
1161
                            host_seq.at(0).second.assignments().insert({symbolic::symbol(new_index_name), new_expr});
1✔
1162
                            access_node->data(new_index_name);
1✔
1163
                        }
1✔
1164
                    } else if (first_dataflow.in_degree(node) > 0 && first_dataflow.out_degree(node) > 0 &&
25✔
1165
                               dynamic_cast<const types::Scalar*>(&sdfg.type(access_node->data())) != nullptr) {
25✔
1166
                        // SSA Dataflow required to check for non-local use of the access node's container.
1167
                        // Intermediate access node (e.g. from a prior BlockFusion): clone
1168
                        // its container so each inlined copy gets its own private scalar
1169
                        auto it = intermediate_renames.find(access_node->data());
×
1170
                        if (it == intermediate_renames.end()) {
×
1171
                            std::string fresh = builder.find_new_name(access_node->data());
×
1172
                            builder.add_container(fresh, sdfg.type(access_node->data()));
×
1173
                            intermediate_renames[access_node->data()] = fresh;
×
1174
                        }
×
1175
                        access_node->data(intermediate_renames[access_node->data()]);
×
1176
                    }
×
1177
                }
46✔
1178
            }
66✔
1179

1180
            // Add memlets with index substitution (producer indvars → consumer expressions)
1181
            for (auto& edge : first_dataflow.edges()) {
46✔
1182
                auto& src_node = edge.src();
46✔
1183
                auto& dst_node = edge.dst();
46✔
1184

1185
                const types::IType* base_type = &edge.base_type();
46✔
1186
                data_flow::Subset new_subset;
46✔
1187
                for (const auto& dim : edge.subset()) {
46✔
1188
                    auto new_dim = dim;
42✔
1189
                    for (const auto& [pvar, mapping] : candidate.index_mappings) {
60✔
1190
                        new_dim = symbolic::subs(new_dim, pvar, mapping);
60✔
1191
                    }
60✔
1192
                    new_dim = symbolic::expand(new_dim);
42✔
1193
                    new_subset.push_back(new_dim);
42✔
1194
                }
42✔
1195

1196
                // Case 1: the producer's array write becomes a scalar write (empty subset).
1197
                // Case 2: keep the remapped array subset so the init writes the accumulator.
1198
                auto* dst_access = dynamic_cast<data_flow::AccessNode*>(&dst_node);
46✔
1199
                if (!init_hoist_ && dst_access != nullptr && dst_access->data() == candidate.container &&
46✔
1200
                    first_dataflow.in_degree(*dst_access) > 0) {
46✔
1201
                    new_subset.clear();
19✔
1202
                    base_type = &tmp_type;
19✔
1203
                }
19✔
1204

1205
                builder.add_memlet(
46✔
1206
                    new_block,
46✔
1207
                    *node_mapping[&src_node],
46✔
1208
                    edge.src_conn(),
46✔
1209
                    *node_mapping[&dst_node],
46✔
1210
                    edge.dst_conn(),
46✔
1211
                    new_subset,
46✔
1212
                    *base_type,
46✔
1213
                    edge.debug_info()
46✔
1214
                );
46✔
1215
            }
46✔
1216
        }
20✔
1217

1218
        // Case 1 only: rewrite consumer reads of the fused arrays to the scalar temps.
1219
        // Case 2 leaves the reduction body untouched (it keeps reading/writing the accumulator,
1220
        // now pre-initialized by the hoisted init block).
1221
        if (!init_hoist_) {
18✔
1222
            size_t num_producer_blocks = fusion_candidates_.size();
17✔
1223
            FusionConsumerUpdateVisitor update_visitor(builder, fusion_candidates_, candidate_temps);
17✔
1224
            update_visitor.dispatch_partial_sequence(*consumer_body_, num_producer_blocks, consumer_body_->size());
17✔
1225
        } else {
17✔
1226
            // Case 2: the hoisted init copy fully overwrites the accumulator before the
1227
            // reduction reads it, so the original init producer map is redundant. Unlike
1228
            // Case 1, the accumulator array stays live, so DCE would not reclaim it — remove
1229
            // the producer explicitly (mirrors how ConsumerIntoProducer removes its loop).
1230
            auto* parent = first_map_.get_parent();
1✔
1231
            auto* parent_seq = dynamic_cast<structured_control_flow::Sequence*>(parent);
1✔
1232
            if (parent_seq != nullptr) {
1✔
1233
                int idx = parent_seq->index(first_map_);
1✔
1234
                if (idx >= 0) {
1✔
1235
                    builder.remove_child(*parent_seq, static_cast<size_t>(idx));
1✔
1236
                }
1✔
1237
            }
1✔
1238
        }
1✔
1239

1240
    } else {
18✔
1241
        // ConsumerIntoProducer (Pattern 2): Inline consumer blocks into the producer's write body
1242
        // Modify the producer block in-place to write to a temp scalar, add a writeback block
1243
        // for the original array, then copy consumer blocks reading from the temp.
1244

1245
        std::vector<std::string> candidate_temps;
3✔
1246
        auto& producer_dataflow = producer_block_->dataflow();
3✔
1247

1248
        for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
6✔
1249
            auto& candidate = fusion_candidates_[cand_idx];
3✔
1250

1251
            auto& container_type = sdfg.type(candidate.container);
3✔
1252
            std::string temp_name = builder.find_new_name("_fused_tmp");
3✔
1253
            types::Scalar tmp_type(container_type.primitive_type());
3✔
1254
            builder.add_container(temp_name, tmp_type);
3✔
1255
            candidate_temps.push_back(temp_name);
3✔
1256

1257
            // Step 1: Modify the original producer block to write to _fused_tmp
1258
            data_flow::Subset original_write_subset;
3✔
1259
            for (auto& node : producer_dataflow.nodes()) {
5✔
1260
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
5✔
1261
                if (access == nullptr || access->data() != candidate.container) continue;
5✔
1262
                if (producer_dataflow.in_degree(*access) == 0) continue;
3✔
1263

1264
                // This is the write access node — save the original subset, then redirect
1265
                for (auto& in_edge : producer_dataflow.in_edges(*access)) {
3✔
1266
                    original_write_subset = in_edge.subset();
3✔
1267
                    in_edge.set_subset({});
3✔
1268
                    in_edge.set_base_type(tmp_type);
3✔
1269
                }
3✔
1270
                access->data(temp_name);
3✔
1271
                break;
3✔
1272
            }
3✔
1273

1274
            // Step 2: Add a writeback block: container[original_subset] = _fused_tmp
1275
            control_flow::Assignments empty_assignments;
3✔
1276
            auto& wb_block = builder.add_block_after(*producer_body_, *producer_block_, empty_assignments);
3✔
1277
            auto& wb_src = builder.add_access(wb_block, temp_name);
3✔
1278
            auto& wb_dst = builder.add_access(wb_block, candidate.container);
3✔
1279
            auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
3✔
1280
            builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", {});
3✔
1281
            builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, original_write_subset);
3✔
1282

1283
            // Step 3: Copy consumer blocks after the writeback block
1284
            structured_control_flow::ControlFlowNode* last_inserted = &wb_block;
3✔
1285

1286
            for (size_t i = 0; i < consumer_body_->size(); ++i) {
6✔
1287
                auto* consumer_block = dynamic_cast<structured_control_flow::Block*>(&consumer_body_->at(i).first);
3✔
1288
                if (consumer_block == nullptr) {
3✔
1289
                    continue;
×
1290
                }
×
1291

1292
                auto& consumer_dataflow = consumer_block->dataflow();
3✔
1293

1294
                // Check if this block reads from the fusion container
1295
                bool reads_container = false;
3✔
1296
                for (auto& node : consumer_dataflow.nodes()) {
11✔
1297
                    auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
11✔
1298
                    if (access != nullptr && access->data() == candidate.container &&
11✔
1299
                        consumer_dataflow.out_degree(*access) > 0) {
11✔
1300
                        reads_container = true;
3✔
1301
                        break;
3✔
1302
                    }
3✔
1303
                }
11✔
1304
                if (!reads_container) {
3✔
1305
                    continue;
×
1306
                }
×
1307

1308
                // Insert a new block after the last inserted block in the producer's body
1309
                auto& new_block = builder.add_block_after(*producer_body_, *last_inserted, empty_assignments);
3✔
1310
                structured_control_flow::Block* empty_block = nullptr;
3✔
1311

1312
                // Deep copy all nodes from consumer block
1313
                std::unordered_map<const data_flow::DataFlowNode*, data_flow::DataFlowNode*> node_mapping;
3✔
1314
                std::unordered_map<std::string, std::string> intermediate_renames;
3✔
1315
                for (auto& node : consumer_dataflow.nodes()) {
11✔
1316
                    node_mapping[&node] = &builder.copy_node(new_block, node);
11✔
1317
                    auto* copied = node_mapping[&node];
11✔
1318
                    if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(copied)) {
11✔
1319
                        if (access_node->data() == candidate.container) {
8✔
1320
                            // Only rename read access nodes to temp; keep write access nodes
1321
                            // pointing to the original container
1322
                            if (consumer_dataflow.in_degree(node) == 0) {
4✔
1323
                                access_node->data(temp_name);
3✔
1324
                            }
3✔
1325
                        } else if (consumer_dataflow.in_degree(node) > 0 && consumer_dataflow.out_degree(node) > 0 &&
4✔
1326
                                   dynamic_cast<const types::Scalar*>(&sdfg.type(access_node->data())) != nullptr) {
4✔
1327
                            // SSA Dataflow required to check for non-local use of the access node's container.
1328
                            // Intermediate access node (e.g. from a prior BlockFusion): clone
1329
                            // its container so each inlined copy gets its own private scalar
1330
                            auto it = intermediate_renames.find(access_node->data());
×
1331
                            if (it == intermediate_renames.end()) {
×
1332
                                std::string fresh = builder.find_new_name(access_node->data());
×
1333
                                builder.add_container(fresh, sdfg.type(access_node->data()));
×
1334
                                intermediate_renames[access_node->data()] = fresh;
×
1335
                            }
×
1336
                            access_node->data(intermediate_renames[access_node->data()]);
×
1337
                        }
×
1338
                        if (access_node->data() == second_loop_.indvar()->get_name() &&
8✔
1339
                            consumer_dataflow.in_degree(node) == 0) {
8✔
1340
                            // Determine the new expression for the index variable of the second loop
1341
                            symbolic::Expression new_expr = SymEngine::null;
×
1342
                            for (auto& c : fusion_candidates_) {
×
1343
                                for (auto& [sym, expr] : c.index_mappings) {
×
1344
                                    if (symbolic::eq(sym, second_loop_.indvar())) {
×
1345
                                        new_expr = expr;
×
1346
                                        break;
×
1347
                                    }
×
1348
                                }
×
1349
                                if (!new_expr.is_null()) {
×
1350
                                    break;
×
1351
                                }
×
1352
                            }
×
1353

1354
                            if (new_expr.is_null() || symbolic::eq(new_expr, first_map_.indvar())) {
×
1355
                                // Simple case: The new expression is simply the index variable of the first map
1356
                                access_node->data(first_map_.indvar()->get_name());
×
1357
                            } else {
×
1358
                                // Complex case: Add an empty block before the new block (if necessary) and store the
1359
                                // shifted index into a new temporary variable with an assignment. Then, replace the
1360
                                // index variable with the new temporary variable
1361
                                if (!empty_block) {
×
1362
                                    empty_block =
×
1363
                                        &builder.add_block_before(*producer_body_, new_block, empty_assignments);
×
1364
                                }
×
1365
                                auto new_index_name = builder.find_new_name();
×
1366
                                builder.add_container(
×
1367
                                    new_index_name, builder.subject().type(first_map_.indvar()->get_name())
×
1368
                                );
×
1369
                                producer_body_->at(0)
×
1370
                                    .second.assignments()
×
1371
                                    .insert({symbolic::symbol(new_index_name), new_expr});
×
1372
                                access_node->data(new_index_name);
×
1373
                            }
×
1374
                        }
×
1375
                    }
8✔
1376
                }
11✔
1377

1378
                // Add memlets with index substitution (consumer indvars → producer expressions)
1379
                for (auto& edge : consumer_dataflow.edges()) {
8✔
1380
                    auto& src_node = edge.src();
8✔
1381
                    auto& dst_node = edge.dst();
8✔
1382

1383
                    const types::IType* base_type = &edge.base_type();
8✔
1384
                    data_flow::Subset new_subset;
8✔
1385
                    for (const auto& dim : edge.subset()) {
9✔
1386
                        auto new_dim = dim;
9✔
1387
                        for (const auto& [cvar, mapping] : candidate.index_mappings) {
13✔
1388
                            new_dim = symbolic::subs(new_dim, cvar, mapping);
13✔
1389
                        }
13✔
1390
                        new_dim = symbolic::expand(new_dim);
9✔
1391
                        new_subset.push_back(new_dim);
9✔
1392
                    }
9✔
1393

1394
                    // For read edges from temp scalar, use empty subset
1395
                    auto* src_access = dynamic_cast<data_flow::AccessNode*>(&src_node);
8✔
1396
                    if (src_access != nullptr && src_access->data() == candidate.container &&
8✔
1397
                        consumer_dataflow.in_degree(*src_access) == 0) {
8✔
1398
                        new_subset.clear();
3✔
1399
                        base_type = &tmp_type;
3✔
1400
                    }
3✔
1401

1402
                    builder.add_memlet(
8✔
1403
                        new_block,
8✔
1404
                        *node_mapping[&src_node],
8✔
1405
                        edge.src_conn(),
8✔
1406
                        *node_mapping[&dst_node],
8✔
1407
                        edge.dst_conn(),
8✔
1408
                        new_subset,
8✔
1409
                        *base_type,
8✔
1410
                        edge.debug_info()
8✔
1411
                    );
8✔
1412
                }
8✔
1413

1414
                last_inserted = &new_block;
3✔
1415
            }
3✔
1416
        }
3✔
1417

1418
        // Remove the consumer loop
1419
        auto* parent = second_loop_.get_parent();
3✔
1420
        auto* parent_seq = dynamic_cast<structured_control_flow::Sequence*>(parent);
3✔
1421
        if (parent_seq != nullptr) {
3✔
1422
            int idx = parent_seq->index(second_loop_);
3✔
1423
            if (idx >= 0) {
3✔
1424
                builder.remove_child(*parent_seq, static_cast<size_t>(idx));
3✔
1425
            }
3✔
1426
        }
3✔
1427
    }
3✔
1428

1429
    analysis_manager.invalidate_all();
21✔
1430
    applied_ = true;
21✔
1431
}
21✔
1432

1433
void MapFusion::to_json(nlohmann::json& j) const {
    // Header: transformation identity; MapFusion carries no tunable parameters.
    j["transformation_type"] = this->name();
    j["parameters"] = nlohmann::json::object();

    // Subgraph slots: "0" holds first_map_, "1" holds second_loop_. from_json reads their
    // "element_id" fields back from exactly these slots.
    // NOTE(review): `false` presumably selects a flat (non-recursive) serialization — TODO confirm.
    serializer::JSONSerializer flat_serializer(false);
    auto& subgraph = j["subgraph"];
    subgraph = nlohmann::json::object();

    subgraph["0"] = nlohmann::json::object();
    flat_serializer.serialize_node(subgraph["0"], first_map_);

    subgraph["1"] = nlohmann::json::object();
    flat_serializer.serialize_node(subgraph["1"], second_loop_);
}
1✔
1445

1446
MapFusion MapFusion::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& desc) {
    // Reads subgraph/<slot>/element_id from the description. `desc` is const, and nlohmann's
    // const operator[] has undefined behavior for a missing key, so presence is checked first
    // and a malformed description is reported as InvalidTransformationDescriptionException.
    auto element_id_at = [&desc](const std::string& slot) -> size_t {
        if (!desc.contains("subgraph") || !desc["subgraph"].contains(slot) ||
            !desc["subgraph"][slot].contains("element_id")) {
            throw InvalidTransformationDescriptionException(
                "Missing subgraph/" + slot + "/element_id in transformation description."
            );
        }
        return desc["subgraph"][slot]["element_id"].get<size_t>();
    };

    // Slot "0" is the first map, slot "1" the second loop (matches to_json).
    auto first_map_id = element_id_at("0");
    auto second_loop_id = element_id_at("1");

    auto first_element = builder.find_element_by_id(first_map_id);
    auto second_element = builder.find_element_by_id(second_loop_id);

    if (first_element == nullptr) {
        throw InvalidTransformationDescriptionException("Element with ID " + std::to_string(first_map_id) + " not found.");
    }
    if (second_element == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(second_loop_id) + " not found."
        );
    }

    auto* first_map = dynamic_cast<structured_control_flow::Map*>(first_element);
    auto* second_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(second_element);

    if (first_map == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(first_map_id) + " is not a Map."
        );
    }
    if (second_loop == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(second_loop_id) + " is not a StructuredLoop."
        );
    }

    return MapFusion(*first_map, *second_loop);
}
1✔
1478

1479
} // namespace transformations
1480
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc