• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 24154538820

08 Apr 2026 07:33PM UTC coverage: 64.986% (+0.1%) from 64.871%
24154538820

push

github

web-flow
Merge pull request #660 from daisytuner/consumer-map-fusion

adds consumer-based map fusion

332 of 379 new or added lines in 1 file covered. (87.6%)

1 existing line in 1 file now uncovered.

29237 of 44990 relevant lines covered (64.99%)

601.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.34
/opt/src/transformations/map_fusion.cpp
1
#include "sdfg/transformations/map_fusion.h"
2

3
#include <isl/ctx.h>
4
#include <isl/map.h>
5
#include <isl/options.h>
6
#include <isl/set.h>
7
#include <isl/space.h>
8
#include <symengine/solve.h>
9
#include "sdfg/analysis/arguments_analysis.h"
10
#include "sdfg/analysis/loop_analysis.h"
11
#include "sdfg/analysis/scope_analysis.h"
12
#include "sdfg/analysis/users.h"
13

14
#include "sdfg/analysis/assumptions_analysis.h"
15
#include "sdfg/control_flow/interstate_edge.h"
16
#include "sdfg/data_flow/data_flow_graph.h"
17
#include "sdfg/structured_control_flow/block.h"
18
#include "sdfg/structured_control_flow/for.h"
19
#include "sdfg/symbolic/delinearization.h"
20
#include "sdfg/symbolic/utils.h"
21

22
namespace sdfg {
23
namespace transformations {
24

25
/// Creates a MapFusion for `first_map` (the producer) and `second_loop` (the consumer).
/// The constructor only stores the two loops; legality is decided later by can_be_applied().
MapFusion::MapFusion(structured_control_flow::Map& first_map, structured_control_flow::StructuredLoop& second_loop)
    : first_map_(first_map),
      second_loop_(second_loop) {
    // Nothing else to initialise here.
}
33✔
27

28
/// Identifier of this transformation.
std::string MapFusion::name() const {
    const std::string transformation_name{"MapFusion"};
    return transformation_name;
}
2✔
29

30
std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> MapFusion::solve_subsets(
31
    const data_flow::Subset& producer_subset,
32
    const data_flow::Subset& consumer_subset,
33
    const std::vector<structured_control_flow::StructuredLoop*>& producer_loops,
34
    const std::vector<structured_control_flow::StructuredLoop*>& consumer_loops,
35
    const symbolic::Assumptions& producer_assumptions,
36
    const symbolic::Assumptions& consumer_assumptions,
37
    bool invert_range_check
38
) {
29✔
39
    // Delinearize subsets to recover multi-dimensional structure from linearized accesses
40
    // e.g. T[i*N + j] with assumptions on bounds -> T[i, j]
41
    auto producer_sub = producer_subset;
29✔
42
    if (producer_sub.size() == 1) {
29✔
43
        auto producer_result = symbolic::delinearize(producer_sub.at(0), producer_assumptions);
20✔
44
        if (producer_result.success) {
20✔
45
            producer_sub = producer_result.indices;
20✔
46
        }
20✔
47
    }
20✔
48
    auto consumer_sub = consumer_subset;
29✔
49
    if (consumer_sub.size() == 1) {
29✔
50
        auto consumer_result = symbolic::delinearize(consumer_sub.at(0), consumer_assumptions);
20✔
51
        if (consumer_result.success) {
20✔
52
            consumer_sub = consumer_result.indices;
20✔
53
        }
20✔
54
    }
20✔
55

56
    // Subset dimensions must match
57
    if (producer_sub.size() != consumer_sub.size()) {
29✔
58
        return {};
×
59
    }
×
60
    if (producer_sub.empty()) {
29✔
61
        return {};
×
62
    }
×
63

64
    // Extract producer indvars
65
    SymEngine::vec_sym producer_vars;
29✔
66
    for (auto* loop : producer_loops) {
38✔
67
        producer_vars.push_back(SymEngine::rcp_static_cast<const SymEngine::Symbol>(loop->indvar()));
38✔
68
    }
38✔
69

70
    // Step 1: Solve the linear equation system using SymEngine
71
    // System: producer_sub[d] - consumer_sub[d] = 0, for each dimension d
72
    // Solve for producer_vars in terms of consumer_vars and parameters
73
    SymEngine::vec_basic equations;
29✔
74
    for (size_t d = 0; d < producer_sub.size(); ++d) {
67✔
75
        equations.push_back(symbolic::sub(producer_sub.at(d), consumer_sub.at(d)));
38✔
76
    }
38✔
77

78
    // Need exactly as many equations as unknowns for a unique solution.
79
    // Underdetermined systems (e.g. linearized access with multiple loop vars)
80
    // cannot be uniquely solved and would crash linsolve.
81
    if (equations.size() != producer_vars.size()) {
29✔
82
        return {};
×
83
    }
×
84

85
    SymEngine::vec_basic solution;
29✔
86
    try {
29✔
87
        solution = SymEngine::linsolve(equations, producer_vars);
29✔
88
    } catch (...) {
29✔
89
        return {};
×
90
    }
×
91
    if (solution.size() != producer_vars.size()) {
29✔
92
        return {};
×
93
    }
×
94
    // Build consumer var set for atom validation
95
    symbolic::SymbolSet consumer_var_set;
29✔
96
    for (auto* loop : consumer_loops) {
38✔
97
        consumer_var_set.insert(loop->indvar());
38✔
98
    }
38✔
99

100
    std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> mappings;
29✔
101
    for (size_t i = 0; i < producer_vars.size(); ++i) {
67✔
102
        auto& sol = solution[i];
38✔
103

104
        // Check for invalid solutions
105
        if (SymEngine::is_a<SymEngine::NaN>(*sol) || SymEngine::is_a<SymEngine::Infty>(*sol)) {
38✔
106
            return {};
×
107
        }
×
108

109
        // Validate that solution atoms are consumer vars or parameters
110
        for (const auto& atom : symbolic::atoms(sol)) {
40✔
111
            if (consumer_var_set.count(atom)) {
40✔
112
                continue;
40✔
113
            }
40✔
114
            bool is_param = false;
×
115
            auto it = consumer_assumptions.find(atom);
×
116
            if (it != consumer_assumptions.end() && it->second.constant()) {
×
117
                is_param = true;
×
118
            }
×
119
            if (!is_param) {
×
120
                it = producer_assumptions.find(atom);
×
121
                if (it != producer_assumptions.end() && it->second.constant()) {
×
122
                    is_param = true;
×
123
                }
×
124
            }
×
125
            if (!is_param) {
×
126
                return {};
×
127
            }
×
128
        }
×
129

130
        mappings.push_back({symbolic::symbol(producer_vars[i]->get_name()), symbolic::expand(sol)});
38✔
131
    }
38✔
132
    // Step 2: ISL integrality validation via map composition
133
    // Build an unconstrained producer access map (no domain bounds on producer vars).
134
    // In map fusion, the producer's computation is inlined into the consumer, so
135
    // the producer's original iteration domain is irrelevant. We only need to verify
136
    // that the equation system has an INTEGER solution for every consumer point.
137
    symbolic::Assumptions unconstrained_producer;
29✔
138
    for (auto* loop : producer_loops) {
38✔
139
        symbolic::Assumption a(loop->indvar());
38✔
140
        a.constant(false);
38✔
141
        unconstrained_producer[loop->indvar()] = a;
38✔
142
    }
38✔
143
    for (const auto& [sym, assump] : producer_assumptions) {
76✔
144
        if (assump.constant() && unconstrained_producer.find(sym) == unconstrained_producer.end()) {
76✔
145
            unconstrained_producer[sym] = assump;
38✔
146
        }
38✔
147
    }
76✔
148

149
    std::string producer_map_str = symbolic::expression_to_map_str(producer_sub, unconstrained_producer);
29✔
150
    // Build consumer access map with full domain constraints
151
    std::string consumer_map_str = symbolic::expression_to_map_str(consumer_sub, consumer_assumptions);
29✔
152

153
    isl_ctx* ctx = isl_ctx_alloc();
29✔
154
    isl_options_set_on_error(ctx, ISL_ON_ERROR_CONTINUE);
29✔
155

156
    isl_map* producer_map = isl_map_read_from_str(ctx, producer_map_str.c_str());
29✔
157
    isl_map* consumer_map = isl_map_read_from_str(ctx, consumer_map_str.c_str());
29✔
158

159
    if (!producer_map || !consumer_map) {
29✔
160
        if (producer_map) isl_map_free(producer_map);
×
161
        if (consumer_map) isl_map_free(consumer_map);
×
162
        isl_ctx_free(ctx);
×
163
        return {};
×
164
    }
×
165

166
    // Align parameters between the two maps
167
    isl_space* params_p = isl_space_params(isl_map_get_space(producer_map));
29✔
168
    isl_space* params_c = isl_space_params(isl_map_get_space(consumer_map));
29✔
169
    isl_space* unified = isl_space_align_params(isl_space_copy(params_p), isl_space_copy(params_c));
29✔
170
    isl_space_free(params_p);
29✔
171
    isl_space_free(params_c);
29✔
172

173
    producer_map = isl_map_align_params(producer_map, isl_space_copy(unified));
29✔
174
    consumer_map = isl_map_align_params(consumer_map, isl_space_copy(unified));
29✔
175

176
    // Save consumer domain before consuming consumer_map in composition
177
    isl_set* consumer_domain = isl_map_domain(isl_map_copy(consumer_map));
29✔
178

179
    // Compute composition: consumer_access ∘ inverse(producer_access)
180
    // This checks whether the equation system producer_subset = consumer_subset
181
    // has an integer solution for each consumer domain point.
182
    isl_map* producer_inverse = isl_map_reverse(producer_map);
29✔
183
    isl_map* composition = isl_map_apply_range(consumer_map, producer_inverse);
29✔
184

185
    // Check single-valuedness: each consumer point maps to at most one producer point
186
    bool single_valued = isl_map_is_single_valued(composition) == isl_bool_true;
29✔
187

188
    // Check domain coverage: every consumer point has a valid integer mapping
189
    isl_set* comp_domain = isl_map_domain(composition);
29✔
190

191
    bool domain_covered = isl_set_is_subset(consumer_domain, comp_domain) == isl_bool_true;
29✔
192

193
    isl_set_free(comp_domain);
29✔
194
    isl_set_free(consumer_domain);
29✔
195

196
    // Step 3: Verify producer write range covers consumer read range.
197
    // The producer only writes a subset of the array if its loops have restricted bounds.
198
    // Fusion is invalid if the consumer reads elements the producer never writes.
199
    bool range_covered = false;
29✔
200
    if (single_valued && domain_covered) {
29✔
201
        std::string constrained_producer_map_str = symbolic::expression_to_map_str(producer_sub, producer_assumptions);
27✔
202
        isl_map* constrained_producer = isl_map_read_from_str(ctx, constrained_producer_map_str.c_str());
27✔
203
        isl_map* consumer_map_copy = isl_map_read_from_str(ctx, consumer_map_str.c_str());
27✔
204

205
        if (constrained_producer && consumer_map_copy) {
27✔
206
            constrained_producer = isl_map_align_params(constrained_producer, isl_space_copy(unified));
27✔
207
            consumer_map_copy = isl_map_align_params(consumer_map_copy, isl_space_copy(unified));
27✔
208

209
            isl_set* producer_range = isl_map_range(constrained_producer);
27✔
210
            isl_set* consumer_range = isl_map_range(consumer_map_copy);
27✔
211

212
            // When arguments are swapped (ConsumerIntoProducer), the "producer"/"consumer"
213
            // labels are inverted. Flip the subset check to always verify:
214
            // actual_consumer_read_range ⊆ actual_producer_write_range
215
            if (invert_range_check) {
27✔
216
                range_covered = isl_set_is_subset(producer_range, consumer_range) == isl_bool_true;
2✔
217
            } else {
25✔
218
                range_covered = isl_set_is_subset(consumer_range, producer_range) == isl_bool_true;
25✔
219
            }
25✔
220

221
            isl_set_free(producer_range);
27✔
222
            isl_set_free(consumer_range);
27✔
223
        } else {
27✔
224
            if (constrained_producer) isl_map_free(constrained_producer);
×
225
            if (consumer_map_copy) isl_map_free(consumer_map_copy);
×
226
        }
×
227
    }
27✔
228

229
    isl_space_free(unified);
29✔
230
    isl_ctx_free(ctx);
29✔
231

232
    if (!single_valued || !domain_covered || !range_covered) {
29✔
233
        return {};
5✔
234
    }
5✔
235

236
    return mappings;
24✔
237
}
29✔
238

239
bool MapFusion::find_write_location(
240
    structured_control_flow::StructuredLoop& loop,
241
    const std::string& container,
242
    std::vector<structured_control_flow::StructuredLoop*>& loops,
243
    structured_control_flow::Sequence*& body,
244
    structured_control_flow::Block*& block
245
) {
4✔
246
    loops.push_back(&loop);
4✔
247
    auto& seq = loop.root();
4✔
248

249
    for (size_t i = 0; i < seq.size(); ++i) {
10✔
250
        auto& child = seq.at(i).first;
6✔
251

252
        if (auto* blk = dynamic_cast<structured_control_flow::Block*>(&child)) {
6✔
253
            // Check if this block writes to the container
254
            auto& dataflow = blk->dataflow();
4✔
255
            for (auto& node : dataflow.nodes()) {
14✔
256
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
14✔
257
                if (access == nullptr || access->data() != container) {
14✔
258
                    continue;
12✔
259
                }
12✔
260
                // Write access: has incoming edges (sink node)
261
                if (dataflow.in_degree(*access) > 0 && dataflow.out_degree(*access) == 0) {
2✔
262
                    if (block != nullptr) {
2✔
263
                        // Multiple write blocks found — ambiguous
NEW
264
                        return false;
×
NEW
265
                    }
×
266
                    body = &seq;
2✔
267
                    block = blk;
2✔
268
                }
2✔
269
            }
2✔
270
        } else if (auto* nested_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&child)) {
4✔
271
            if (!find_write_location(*nested_loop, container, loops, body, block)) {
2✔
NEW
272
                return false;
×
NEW
273
            }
×
274
            // If we didn't find the write in this subtree, pop the loop back off
275
            if (loops.back() != &loop && block == nullptr) {
2✔
276
                // The recursive call already popped — but we need to check
NEW
277
            }
×
278
        }
2✔
279
    }
6✔
280

281
    // If we didn't find the write in this subtree, remove this loop from the chain
282
    if (block == nullptr) {
4✔
NEW
283
        loops.pop_back();
×
NEW
284
    }
×
285

286
    return true;
4✔
287
}
4✔
288

289
bool MapFusion::find_read_location(
290
    structured_control_flow::StructuredLoop& loop,
291
    const std::string& container,
292
    std::vector<structured_control_flow::StructuredLoop*>& loops,
293
    structured_control_flow::Sequence*& body
294
) {
2✔
295
    loops.push_back(&loop);
2✔
296
    auto& seq = loop.root();
2✔
297

298
    for (size_t i = 0; i < seq.size(); ++i) {
5✔
299
        auto& child = seq.at(i).first;
3✔
300

301
        if (auto* blk = dynamic_cast<structured_control_flow::Block*>(&child)) {
3✔
302
            // Check if this block reads from the container
303
            auto& dataflow = blk->dataflow();
2✔
304
            for (auto& node : dataflow.nodes()) {
8✔
305
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
8✔
306
                if (access == nullptr || access->data() != container) {
8✔
307
                    continue;
7✔
308
                }
7✔
309
                // Read access: has outgoing edges (source node)
310
                if (dataflow.in_degree(*access) == 0 && dataflow.out_degree(*access) > 0) {
1✔
311
                    if (body != nullptr && body != &seq) {
1✔
312
                        // Reads at different sequence levels — ambiguous
NEW
313
                        return false;
×
NEW
314
                    }
×
315
                    body = &seq;
1✔
316
                }
1✔
317
            }
1✔
318
        } else if (auto* nested_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&child)) {
2✔
319
            if (!find_read_location(*nested_loop, container, loops, body)) {
1✔
NEW
320
                return false;
×
NEW
321
            }
×
322
        }
1✔
323
    }
3✔
324

325
    // If we didn't find any reads in this subtree, remove this loop from the chain
326
    if (body == nullptr) {
2✔
NEW
327
        loops.pop_back();
×
NEW
328
    }
×
329

330
    return true;
2✔
331
}
2✔
332

333
/// Decides whether first_map_ (producer) and second_loop_ (consumer) can be fused and,
/// if so, records the data apply() uses.
///
/// Checks, in order: both loop bodies are non-empty; both loops share the same parent
/// Sequence and the consumer directly follows the producer with an empty transition;
/// at least one of them is perfectly nested; the side being inlined is perfectly
/// parallel; the consumer neither writes a container the producer writes nor writes a
/// container the producer reads; and for every container the producer writes and the
/// consumer reads, each distinct consumer read subset can be solved against the
/// producer write subset by solve_subsets().
///
/// Side effects (also on a `false` return): clears fusion_candidates_ and may set
/// direction_, producer_loops_, producer_body_, producer_block_, consumer_loops_ and
/// consumer_body_. `builder` is not used by this check.
///
/// @return true iff at least one FusionCandidate was recorded in fusion_candidates_.
bool MapFusion::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    fusion_candidates_.clear();

    // no use in fusing empty loops. Also presumed to not be empty further down
    if (first_map_.root().size() == 0 || second_loop_.root().size() == 0) {
        return false;
    }

    // Criterion: Get parent scope and verify both loops are sequential children
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
    auto* first_parent = scope_analysis.parent_scope(&first_map_);
    auto* second_parent = scope_analysis.parent_scope(&second_loop_);
    if (first_parent == nullptr || second_parent == nullptr) {
        return false;
    }
    if (first_parent != second_parent) {
        return false;
    }

    auto* parent_sequence = dynamic_cast<structured_control_flow::Sequence*>(first_parent);
    if (parent_sequence == nullptr) {
        return false;
    }

    int first_index = parent_sequence->index(first_map_);
    int second_index = parent_sequence->index(second_loop_);
    if (first_index == -1 || second_index == -1) {
        return false;
    }
    // The consumer must be the immediate successor of the producer.
    if (second_index != first_index + 1) {
        return false;
    }

    // Criterion: Transition between maps should have no assignments
    auto& transition = parent_sequence->at(first_index).second;
    if (!transition.empty()) {
        return false;
    }
    // Determine fusion pattern based on nesting properties
    auto& loop_analysis = analysis_manager.get<analysis::LoopAnalysis>();
    auto first_loop_info = loop_analysis.loop_info(&first_map_);
    auto second_loop_info = loop_analysis.loop_info(&second_loop_);

    bool first_nested = first_loop_info.is_perfectly_nested;
    bool second_nested = second_loop_info.is_perfectly_nested;

    // Both non-perfectly-nested: not supported
    if (!first_nested && !second_nested) {
        return false;
    }

    if (first_nested && second_nested) {
        // Pattern 1: Both perfectly nested — producer into consumer (original path)
        direction_ = FusionDirection::ProducerIntoConsumer;
    } else if (!first_nested && second_nested) {
        // Pattern 2: Producer non-perfectly-nested, consumer perfectly nested
        direction_ = FusionDirection::ConsumerIntoProducer;
    } else {
        // Reverse Pattern 2: Producer perfectly nested, consumer non-perfectly-nested
        direction_ = FusionDirection::ProducerIntoConsumer;
    }

    // The side being inlined must be all-parallel (all Maps) so iterations can be reordered.
    // ProducerIntoConsumer: producer is replicated at each consumer site — producer must be all-parallel.
    // ConsumerIntoProducer: consumer is reordered into producer's nest — consumer must be all-parallel.
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
        if (!first_loop_info.is_perfectly_parallel) {
            return false;
        }
    } else {
        if (!second_loop_info.is_perfectly_parallel) {
            return false;
        }
    }

    // Locate producer write point
    producer_loops_.clear();
    producer_body_ = nullptr;
    producer_block_ = nullptr;

    if (first_nested) {
        // Perfectly nested: walk the at(0).first chain
        // producer_loops_ collects the loop chain, producer_body_ ends up as the
        // innermost loop body and producer_block_ as its first child (must be a Block).
        producer_loops_.push_back(&first_map_);
        producer_body_ = &first_map_.root();
        structured_control_flow::ControlFlowNode* node = &first_map_.root().at(0).first;
        while (auto* nested = dynamic_cast<structured_control_flow::StructuredLoop*>(node)) {
            producer_loops_.push_back(nested);
            producer_body_ = &nested->root();
            node = &nested->root().at(0).first;
        }
        producer_block_ = dynamic_cast<structured_control_flow::Block*>(node);
        if (producer_block_ == nullptr) {
            return false;
        }
    } else {
        // Non-perfectly-nested: search recursively for the write block
        // We need to know which containers to look for, but we don't know them yet.
        // Defer write location search until after fusion_containers are identified.
    }

    // Locate consumer read point
    consumer_loops_.clear();
    consumer_body_ = nullptr;

    if (second_nested) {
        // Perfectly nested: walk the at(0).first chain
        // consumer_body_ ends up as the innermost loop body.
        consumer_loops_.push_back(&second_loop_);
        consumer_body_ = &second_loop_.root();
        structured_control_flow::ControlFlowNode* node = &second_loop_.root().at(0).first;
        while (auto* nested = dynamic_cast<structured_control_flow::StructuredLoop*>(node)) {
            consumer_loops_.push_back(nested);
            consumer_body_ = &nested->root();
            node = &nested->root().at(0).first;
        }
    } else {
        // Non-perfectly-nested: defer read location search until after fusion_containers are identified.
    }

    // Get arguments analysis to identify inputs/outputs of each loop
    auto& arguments_analysis = analysis_manager.get<analysis::ArgumentsAnalysis>();
    auto first_args = arguments_analysis.arguments(analysis_manager, first_map_);
    auto second_args = arguments_analysis.arguments(analysis_manager, second_loop_);

    std::unordered_set<std::string> first_inputs;
    std::unordered_set<std::string> first_outputs;
    for (const auto& [name, arg] : first_args) {
        if (arg.is_output) {
            first_outputs.insert(name);
        }
        if (arg.is_input) {
            first_inputs.insert(name);
        }
    }

    // Classify the consumer's arguments against the producer's:
    //  - producer writes & consumer writes  -> reject
    //  - producer writes & consumer reads   -> fusion container
    //  - producer reads  & consumer writes  -> reject
    std::unordered_set<std::string> fusion_containers;
    for (const auto& [name, arg] : second_args) {
        if (first_outputs.contains(name)) {
            if (arg.is_output) {
                return false;
            }
            if (arg.is_input) {
                fusion_containers.insert(name);
            }
        }
        if (first_inputs.contains(name) && arg.is_output) {
            return false;
        }
    }
    if (fusion_containers.empty()) {
        return false;
    }

    // Now that we know the fusion containers, resolve deferred locations
    if (!first_nested) {
        // Non-perfectly-nested producer: find write location for the first fusion container
        // All fusion containers must be written at the same block for this to work
        for (const auto& container : fusion_containers) {
            std::vector<structured_control_flow::StructuredLoop*> write_loops;
            structured_control_flow::Sequence* write_body = nullptr;
            structured_control_flow::Block* write_block = nullptr;

            if (!find_write_location(first_map_, container, write_loops, write_body, write_block)) {
                return false;
            }
            if (write_block == nullptr) {
                return false;
            }

            if (producer_block_ == nullptr) {
                // First container: set the locations
                producer_loops_ = write_loops;
                producer_body_ = write_body;
                producer_block_ = write_block;
            } else {
                // Subsequent containers must be in the same block
                if (write_block != producer_block_) {
                    return false;
                }
            }
        }
    }

    if (!second_nested) {
        // Non-perfectly-nested consumer: find read location for the first fusion container
        // All fusion containers must be read at the same sequence for this to work
        for (const auto& container : fusion_containers) {
            std::vector<structured_control_flow::StructuredLoop*> read_loops;
            structured_control_flow::Sequence* read_body = nullptr;

            if (!find_read_location(second_loop_, container, read_loops, read_body)) {
                return false;
            }
            if (read_body == nullptr) {
                return false;
            }

            if (consumer_body_ == nullptr) {
                // First container: set the locations
                consumer_loops_ = read_loops;
                consumer_body_ = read_body;
            } else {
                // Subsequent containers must be at the same sequence
                if (read_body != consumer_body_) {
                    return false;
                }
            }
        }
    }

    // Get assumptions for the resolved write/read locations
    // (consumer assumptions are taken at the first child of the consumer body).
    auto& assumptions_analysis = analysis_manager.get<analysis::AssumptionsAnalysis>();
    auto& producer_assumptions = assumptions_analysis.get(*producer_block_);
    auto& consumer_assumptions = assumptions_analysis.get(consumer_body_->at(0).first);

    // For each fusion container, find the producer memlet and collect unique consumer subsets
    auto& first_dataflow = producer_block_->dataflow();
    for (const auto& container : fusion_containers) {
        // Find unique producer write in first map
        // Every access node of the container in the producer block must be a pure sink
        // with exactly one Computational incoming memlet, and there must be exactly one.
        data_flow::Memlet* producer_memlet = nullptr;

        for (auto& node : first_dataflow.nodes()) {
            auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
            if (access == nullptr || access->data() != container) {
                continue;
            }
            if (first_dataflow.in_degree(*access) != 1 || first_dataflow.out_degree(*access) != 0) {
                return false;
            }
            auto& iedge = *first_dataflow.in_edges(*access).begin();
            if (iedge.type() != data_flow::MemletType::Computational) {
                return false;
            }
            if (producer_memlet != nullptr) {
                return false;
            }
            producer_memlet = &iedge;
        }
        if (producer_memlet == nullptr) {
            return false;
        }

        // Scalar (empty-subset) writes are not fused.
        const auto& producer_subset = producer_memlet->subset();
        if (producer_subset.empty()) {
            return false;
        }

        // Collect all unique subsets from consumer blocks
        std::vector<data_flow::Subset> unique_subsets;
        for (size_t i = 0; i < consumer_body_->size(); ++i) {
            auto* block = dynamic_cast<structured_control_flow::Block*>(&consumer_body_->at(i).first);
            if (block == nullptr) {
                // Skip non-block children (e.g. nested loops that are not related)
                continue;
            }

            auto& dataflow = block->dataflow();
            for (auto& node : dataflow.nodes()) {
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
                if (access == nullptr || access->data() != container) {
                    continue;
                }
                // Consumer-side accesses of the container must be pure reads (source nodes).
                if (dataflow.in_degree(*access) != 0 || dataflow.out_degree(*access) == 0) {
                    return false;
                }

                // Check all read memlets from this access
                for (auto& memlet : dataflow.out_edges(*access)) {
                    if (memlet.type() != data_flow::MemletType::Computational) {
                        return false;
                    }

                    auto& consumer_subset = memlet.subset();
                    if (consumer_subset.size() != producer_subset.size()) {
                        return false;
                    }

                    // Check if this subset is already in unique_subsets
                    // (element-wise symbolic::eq comparison; linear scan).
                    bool found = false;
                    for (const auto& existing : unique_subsets) {
                        if (existing.size() != consumer_subset.size()) continue;
                        bool match = true;
                        for (size_t d = 0; d < existing.size(); ++d) {
                            if (!symbolic::eq(existing[d], consumer_subset[d])) {
                                match = false;
                                break;
                            }
                        }
                        if (match) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        unique_subsets.push_back(consumer_subset);
                    }
                }
            }
        }

        // For each unique consumer subset, solve index mappings and create a FusionCandidate
        // The direction determines which side's indvars are solved for
        // A single unsolvable subset rejects the whole fusion.
        for (const auto& consumer_subset : unique_subsets) {
            std::vector<std::pair<symbolic::Symbol, symbolic::Expression>> mappings;

            if (direction_ == FusionDirection::ProducerIntoConsumer) {
                // Solve producer indvars in terms of consumer indvars
                mappings = solve_subsets(
                    producer_subset,
                    consumer_subset,
                    producer_loops_,
                    consumer_loops_,
                    producer_assumptions,
                    consumer_assumptions
                );
            } else {
                // ConsumerIntoProducer: solve consumer indvars in terms of producer indvars
                // Arguments are swapped, so invert the range check direction
                mappings = solve_subsets(
                    consumer_subset,
                    producer_subset,
                    consumer_loops_,
                    producer_loops_,
                    consumer_assumptions,
                    producer_assumptions,
                    true
                );
            }

            if (mappings.empty()) {
                return false;
            }

            FusionCandidate candidate;
            candidate.container = container;
            candidate.consumer_subset = consumer_subset;
            candidate.index_mappings = std::move(mappings);

            fusion_candidates_.push_back(candidate);
        }
    }

    // Criterion: At least one valid fusion candidate
    return !fusion_candidates_.empty();
}
27✔
677

678
void MapFusion::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
13✔
679
    auto& sdfg = builder.subject();
13✔
680

681
    if (direction_ == FusionDirection::ProducerIntoConsumer) {
13✔
682
        // Pattern 1 + Reverse Pattern 2: Inline producer blocks into consumer's read body
683
        auto& first_dataflow = producer_block_->dataflow();
12✔
684

685
        // For each fusion candidate, create a temp and insert a producer block
686
        std::vector<std::string> candidate_temps;
12✔
687

688
        for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
25✔
689
            auto& candidate = fusion_candidates_[cand_idx];
13✔
690

691
            auto& container_type = sdfg.type(candidate.container);
13✔
692
            std::string temp_name = builder.find_new_name("_fused_tmp");
13✔
693
            types::Scalar tmp_type(container_type.primitive_type());
13✔
694
            builder.add_container(temp_name, tmp_type);
13✔
695
            candidate_temps.push_back(temp_name);
13✔
696

697
            // Insert a producer block at the beginning of the consumer's body
698
            auto& first_child = consumer_body_->at(0).first;
13✔
699
            control_flow::Assignments empty_assignments;
13✔
700
            auto& new_block = builder.add_block_before(*consumer_body_, first_child, empty_assignments);
13✔
701

702
            // Deep copy all nodes from producer block to new block
703
            std::unordered_map<const data_flow::DataFlowNode*, data_flow::DataFlowNode*> node_mapping;
13✔
704
            for (auto& node : first_dataflow.nodes()) {
42✔
705
                node_mapping[&node] = &builder.copy_node(new_block, node);
42✔
706
                auto* copied = node_mapping[&node];
42✔
707
                if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(copied)) {
42✔
708
                    if (access_node->data() == candidate.container) {
29✔
709
                        access_node->data(temp_name);
13✔
710
                    }
13✔
711
                }
29✔
712
            }
42✔
713

714
            // Add memlets with index substitution (producer indvars → consumer expressions)
715
            for (auto& edge : first_dataflow.edges()) {
29✔
716
                auto& src_node = edge.src();
29✔
717
                auto& dst_node = edge.dst();
29✔
718

719
                const types::IType* base_type = &edge.base_type();
29✔
720
                data_flow::Subset new_subset;
29✔
721
                for (const auto& dim : edge.subset()) {
32✔
722
                    auto new_dim = dim;
32✔
723
                    for (const auto& [pvar, mapping] : candidate.index_mappings) {
44✔
724
                        new_dim = symbolic::subs(new_dim, pvar, mapping);
44✔
725
                    }
44✔
726
                    new_dim = symbolic::expand(new_dim);
32✔
727
                    new_subset.push_back(new_dim);
32✔
728
                }
32✔
729

730
                // For output edges to temp scalar, use empty subset
731
                auto* dst_access = dynamic_cast<data_flow::AccessNode*>(&dst_node);
29✔
732
                if (dst_access != nullptr && dst_access->data() == candidate.container &&
29✔
733
                    first_dataflow.in_degree(*dst_access) > 0) {
29✔
734
                    new_subset.clear();
13✔
735
                    base_type = &tmp_type;
13✔
736
                }
13✔
737

738
                builder.add_memlet(
29✔
739
                    new_block,
29✔
740
                    *node_mapping[&src_node],
29✔
741
                    edge.src_conn(),
29✔
742
                    *node_mapping[&dst_node],
29✔
743
                    edge.dst_conn(),
29✔
744
                    new_subset,
29✔
745
                    *base_type,
29✔
746
                    edge.debug_info()
29✔
747
                );
29✔
748
            }
29✔
749
        }
13✔
750

751
        // Update all read accesses in consumer blocks to point to the appropriate temp
752
        size_t num_producer_blocks = fusion_candidates_.size();
12✔
753

754
        for (size_t block_idx = num_producer_blocks; block_idx < consumer_body_->size(); ++block_idx) {
25✔
755
            auto* block = dynamic_cast<structured_control_flow::Block*>(&consumer_body_->at(block_idx).first);
13✔
756
            if (block == nullptr) {
13✔
NEW
757
                continue;
×
NEW
758
            }
×
759

760
            auto& dataflow = block->dataflow();
13✔
761

762
            for (auto& node : dataflow.nodes()) {
46✔
763
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
46✔
764
                if (access == nullptr || dataflow.out_degree(*access) == 0) {
46✔
765
                    continue;
28✔
766
                }
28✔
767

768
                std::string original_container = access->data();
18✔
769

770
                for (auto& memlet : dataflow.out_edges(*access)) {
19✔
771
                    if (memlet.type() != data_flow::MemletType::Computational) {
19✔
NEW
772
                        continue;
×
NEW
773
                    }
×
774

775
                    const auto& memlet_subset = memlet.subset();
19✔
776

777
                    for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
24✔
778
                        auto& candidate = fusion_candidates_[cand_idx];
20✔
779

780
                        if (original_container != candidate.container) {
20✔
781
                            continue;
4✔
782
                        }
4✔
783

784
                        if (memlet_subset.size() != candidate.consumer_subset.size()) {
16✔
NEW
785
                            continue;
×
NEW
786
                        }
×
787

788
                        bool subset_matches = true;
16✔
789
                        for (size_t d = 0; d < memlet_subset.size(); ++d) {
34✔
790
                            if (!symbolic::eq(memlet_subset[d], candidate.consumer_subset[d])) {
19✔
791
                                subset_matches = false;
1✔
792
                                break;
1✔
793
                            }
1✔
794
                        }
19✔
795

796
                        if (!subset_matches) {
16✔
797
                            continue;
1✔
798
                        }
1✔
799

800
                        const auto& temp_name = candidate_temps[cand_idx];
15✔
801
                        auto& temp_type = sdfg.type(temp_name);
15✔
802

803
                        access->data(temp_name);
15✔
804

805
                        memlet.set_subset({});
15✔
806
                        memlet.set_base_type(temp_type);
15✔
807

808
                        for (auto& in_edge : dataflow.in_edges(*access)) {
15✔
NEW
809
                            in_edge.set_subset({});
×
NEW
810
                            in_edge.set_base_type(temp_type);
×
NEW
811
                        }
×
812

813
                        break;
15✔
814
                    }
16✔
815
                }
19✔
816
            }
18✔
817
        }
13✔
818

819
    } else {
12✔
820
        // ConsumerIntoProducer (Pattern 2): Inline consumer blocks into the producer's write body
821
        // Modify the producer block in-place to write to a temp scalar, add a writeback block
822
        // for the original array, then copy consumer blocks reading from the temp.
823

824
        std::vector<std::string> candidate_temps;
1✔
825
        auto& producer_dataflow = producer_block_->dataflow();
1✔
826

827
        for (size_t cand_idx = 0; cand_idx < fusion_candidates_.size(); ++cand_idx) {
2✔
828
            auto& candidate = fusion_candidates_[cand_idx];
1✔
829

830
            auto& container_type = sdfg.type(candidate.container);
1✔
831
            std::string temp_name = builder.find_new_name("_fused_tmp");
1✔
832
            types::Scalar tmp_type(container_type.primitive_type());
1✔
833
            builder.add_container(temp_name, tmp_type);
1✔
834
            candidate_temps.push_back(temp_name);
1✔
835

836
            // Step 1: Modify the original producer block to write to _fused_tmp
837
            data_flow::Subset original_write_subset;
1✔
838
            for (auto& node : producer_dataflow.nodes()) {
2✔
839
                auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
2✔
840
                if (access == nullptr || access->data() != candidate.container) continue;
2✔
841
                if (producer_dataflow.in_degree(*access) == 0) continue;
1✔
842

843
                // This is the write access node — save the original subset, then redirect
844
                for (auto& in_edge : producer_dataflow.in_edges(*access)) {
1✔
845
                    original_write_subset = in_edge.subset();
1✔
846
                    in_edge.set_subset({});
1✔
847
                    in_edge.set_base_type(tmp_type);
1✔
848
                }
1✔
849
                access->data(temp_name);
1✔
850
                break;
1✔
851
            }
1✔
852

853
            // Step 2: Add a writeback block: container[original_subset] = _fused_tmp
854
            control_flow::Assignments empty_assignments;
1✔
855
            auto& wb_block = builder.add_block_after(*producer_body_, *producer_block_, empty_assignments);
1✔
856
            auto& wb_src = builder.add_access(wb_block, temp_name);
1✔
857
            auto& wb_dst = builder.add_access(wb_block, candidate.container);
1✔
858
            auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
1✔
859
            builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", {});
1✔
860
            builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, original_write_subset);
1✔
861

862
            // Step 3: Copy consumer blocks after the writeback block
863
            structured_control_flow::ControlFlowNode* last_inserted = &wb_block;
1✔
864

865
            for (size_t i = 0; i < consumer_body_->size(); ++i) {
2✔
866
                auto* consumer_block = dynamic_cast<structured_control_flow::Block*>(&consumer_body_->at(i).first);
1✔
867
                if (consumer_block == nullptr) {
1✔
868
                    continue;
×
869
                }
×
870

871
                auto& consumer_dataflow = consumer_block->dataflow();
1✔
872

873
                // Check if this block reads from the fusion container
874
                bool reads_container = false;
1✔
875
                for (auto& node : consumer_dataflow.nodes()) {
3✔
876
                    auto* access = dynamic_cast<data_flow::AccessNode*>(&node);
3✔
877
                    if (access != nullptr && access->data() == candidate.container &&
3✔
878
                        consumer_dataflow.out_degree(*access) > 0) {
3✔
879
                        reads_container = true;
1✔
880
                        break;
1✔
881
                    }
1✔
882
                }
3✔
883
                if (!reads_container) {
1✔
NEW
884
                    continue;
×
NEW
885
                }
×
886

887
                // Insert a new block after the last inserted block in the producer's body
888
                auto& new_block = builder.add_block_after(*producer_body_, *last_inserted, empty_assignments);
1✔
889

890
                // Deep copy all nodes from consumer block
891
                std::unordered_map<const data_flow::DataFlowNode*, data_flow::DataFlowNode*> node_mapping;
1✔
892
                for (auto& node : consumer_dataflow.nodes()) {
3✔
893
                    node_mapping[&node] = &builder.copy_node(new_block, node);
3✔
894
                    auto* copied = node_mapping[&node];
3✔
895
                    if (auto* access_node = dynamic_cast<data_flow::AccessNode*>(copied)) {
3✔
896
                        if (access_node->data() == candidate.container) {
2✔
897
                            access_node->data(temp_name);
1✔
898
                        }
1✔
899
                    }
2✔
900
                }
3✔
901

902
                // Add memlets with index substitution (consumer indvars → producer expressions)
903
                for (auto& edge : consumer_dataflow.edges()) {
2✔
904
                    auto& src_node = edge.src();
2✔
905
                    auto& dst_node = edge.dst();
2✔
906

907
                    const types::IType* base_type = &edge.base_type();
2✔
908
                    data_flow::Subset new_subset;
2✔
909
                    for (const auto& dim : edge.subset()) {
4✔
910
                        auto new_dim = dim;
4✔
911
                        for (const auto& [cvar, mapping] : candidate.index_mappings) {
8✔
912
                            new_dim = symbolic::subs(new_dim, cvar, mapping);
8✔
913
                        }
8✔
914
                        new_dim = symbolic::expand(new_dim);
4✔
915
                        new_subset.push_back(new_dim);
4✔
916
                    }
4✔
917

918
                    // For read edges from temp scalar, use empty subset
919
                    auto* src_access = dynamic_cast<data_flow::AccessNode*>(&src_node);
2✔
920
                    if (src_access != nullptr && src_access->data() == candidate.container &&
2✔
921
                        consumer_dataflow.in_degree(*src_access) == 0) {
2✔
922
                        new_subset.clear();
1✔
923
                        base_type = &tmp_type;
1✔
924
                    }
1✔
925

926
                    builder.add_memlet(
2✔
927
                        new_block,
2✔
928
                        *node_mapping[&src_node],
2✔
929
                        edge.src_conn(),
2✔
930
                        *node_mapping[&dst_node],
2✔
931
                        edge.dst_conn(),
2✔
932
                        new_subset,
2✔
933
                        *base_type,
2✔
934
                        edge.debug_info()
2✔
935
                    );
2✔
936
                }
2✔
937

938
                last_inserted = &new_block;
1✔
939
            }
1✔
940
        }
1✔
941

942
        // Remove the consumer loop
943
        auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
1✔
944
        auto* parent = scope_analysis.parent_scope(&second_loop_);
1✔
945
        auto* parent_seq = dynamic_cast<structured_control_flow::Sequence*>(parent);
1✔
946
        if (parent_seq != nullptr) {
1✔
947
            int idx = parent_seq->index(second_loop_);
1✔
948
            if (idx >= 0) {
1✔
949
                builder.remove_child(*parent_seq, static_cast<size_t>(idx));
1✔
950
            }
1✔
951
        }
1✔
952
    }
1✔
953

954
    analysis_manager.invalidate_all();
13✔
955
    applied_ = true;
13✔
956
}
13✔
957

958
/**
 * Serializes this transformation into `j`.
 *
 * Writes "transformation_type" (the value of name()) and a "subgraph" object: entry "0" holds the
 * first map's element id with type "map", and entry "1" holds the second loop's element id with
 * type "map" when it is a Map and "for" otherwise.
 */
void MapFusion::to_json(nlohmann::json& j) const {
    const bool second_is_map = dynamic_cast<structured_control_flow::Map*>(&second_loop_) != nullptr;
    const std::string second_type = second_is_map ? "map" : "for";

    j["transformation_type"] = this->name();
    j["subgraph"] = {
        {"0", {{"element_id", first_map_.element_id()}, {"type", "map"}}},
        {"1", {{"element_id", second_loop_.element_id()}, {"type", second_type}}}
    };
}
1✔
969

970
/**
 * Reconstructs a MapFusion from a description in the format produced by to_json().
 *
 * Reads desc["subgraph"]["0"]["element_id"] (must name a Map) and
 * desc["subgraph"]["1"]["element_id"] (must name a StructuredLoop) and looks both up in `builder`.
 *
 * @throws InvalidTransformationDescriptionException if a required field is missing or not an
 *         unsigned integer, if an element id is not found, or if an element has the wrong type.
 */
MapFusion MapFusion::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& desc) {
    // Checked access: operator[] on a const json with a missing key is undefined behavior.
    size_t first_map_id = 0;
    size_t second_loop_id = 0;
    try {
        const auto& subgraph = desc.at("subgraph");
        first_map_id = subgraph.at("0").at("element_id").get<size_t>();
        second_loop_id = subgraph.at("1").at("element_id").get<size_t>();
    } catch (const nlohmann::json::exception& e) {
        throw InvalidTransformationDescriptionException(
            std::string("Invalid MapFusion description: ") + e.what()
        );
    }

    auto first_element = builder.find_element_by_id(first_map_id);
    auto second_element = builder.find_element_by_id(second_loop_id);

    if (first_element == nullptr) {
        throw InvalidTransformationDescriptionException("Element with ID " + std::to_string(first_map_id) + " not found.");
    }
    if (second_element == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(second_loop_id) + " not found."
        );
    }

    auto* first_map = dynamic_cast<structured_control_flow::Map*>(first_element);
    auto* second_loop = dynamic_cast<structured_control_flow::StructuredLoop*>(second_element);

    if (first_map == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(first_map_id) + " is not a Map."
        );
    }
    if (second_loop == nullptr) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(second_loop_id) + " is not a StructuredLoop."
        );
    }

    return MapFusion(*first_map, *second_loop);
}
1✔
1002

1003
} // namespace transformations
1004
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc