• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27462544866

13 Jun 2026 09:11AM UTC coverage: 61.274% (-0.06%) from 61.331%
27462544866

push

github

web-flow
simplifies local storage transformations (#758)

275 of 325 new or added lines in 2 files covered. (84.62%)

5 existing lines in 1 file now uncovered.

36247 of 59156 relevant lines covered (61.27%)

1124.74 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.24
/opt/src/transformations/out_local_storage.cpp
1
#include "sdfg/transformations/out_local_storage.h"
2

3
#include <cstddef>
4
#include <functional>
5
#include <string>
6

7
#include "sdfg/analysis/memory_layout_analysis.h"
8
#include "sdfg/analysis/users.h"
9
#include "sdfg/builder/structured_sdfg_builder.h"
10
#include "sdfg/data_flow/access_node.h"
11
#include "sdfg/data_flow/library_nodes/barrier_local_node.h"
12
#include "sdfg/data_flow/memlet.h"
13
#include "sdfg/passes/structured_control_flow/dead_cfg_elimination.h"
14
#include "sdfg/passes/structured_control_flow/sequence_fusion.h"
15
#include "sdfg/structured_control_flow/if_else.h"
16
#include "sdfg/structured_control_flow/sequence.h"
17
#include "sdfg/structured_control_flow/structured_loop.h"
18
#include "sdfg/symbolic/symbolic.h"
19
#include "sdfg/targets/gpu/gpu_schedule_type.h"
20
#include "sdfg/types/array.h"
21
#include "sdfg/types/pointer.h"
22
#include "sdfg/types/scalar.h"
23

24
namespace sdfg {
25
namespace transformations {
26

27
/// Creates an OutLocalStorage transformation that stages the container referenced
/// by `access_node` inside `loop` in a local buffer of storage type `storage_type`.
///
/// The container name is read from `access_node` once, here, and cached in
/// `container_`; the loop and access node are held by reference.
OutLocalStorage::OutLocalStorage(
    structured_control_flow::StructuredLoop& loop,
    const data_flow::AccessNode& access_node,
    const types::StorageType& storage_type
)
    : loop_(loop),
      access_node_(access_node),
      container_(access_node.data()),
      storage_type_(storage_type) {}
29✔
33

34
/// Identifier of this transformation; also emitted as "transformation_type" by to_json().
std::string OutLocalStorage::name() const {
    return std::string("OutLocalStorage");
}
5✔
35

36
bool OutLocalStorage::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
27✔
37
    auto& sdfg = builder.subject();
27✔
38
    auto& body = this->loop_.root();
27✔
39

40
    tile_info_ = TileInfo{};
27✔
41

42
    // Criterion: Container must exist and is pointer
43
    if (!sdfg.exists(this->container_)) {
27✔
44
        return false;
×
45
    }
×
46
    auto& type = sdfg.type(this->container_);
27✔
47
    if (type.type_id() != types::TypeID::Pointer) {
27✔
NEW
48
        return false;
×
NEW
49
    }
×
50

51
    // Criterion: Container must be used in the loop body
52
    auto& users = analysis_manager.get<analysis::Users>();
27✔
53
    analysis::UsersView body_users(users, body);
27✔
54
    if (body_users.uses(this->container_).empty()) {
27✔
55
        return false;
2✔
56
    }
2✔
57

58
    // Criterion: Container must have writes (this is OutLocalStorage, not InLocalStorage)
59
    if (body_users.writes(this->container_).empty()) {
25✔
60
        return false;
1✔
61
    }
1✔
62

63
    // Determine if container is also read (read-write vs write-only)
64
    tile_info_.has_read = !body_users.reads(this->container_).empty();
24✔
65

66
    auto& mla = analysis_manager.get<analysis::MemoryLayoutAnalysis>();
24✔
67

68
    // Find a representative memlet from the access node to identify its group.
69
    // An access node may have multiple edges belonging to different tile groups.
70
    // We iterate all edges and select the first one whose tile group is valid
71
    // at the target loop level.
72
    const analysis::MemoryTileGroup* group = nullptr;
24✔
73
    auto& dfg = access_node_.get_parent();
24✔
74
    for (auto& memlet : dfg.in_edges(access_node_)) {
24✔
75
        auto* candidate = mla.tile_group_for(loop_, memlet);
13✔
76
        if (candidate) {
13✔
77
            group = candidate;
13✔
78
            break;
13✔
79
        }
13✔
80
    }
13✔
81
    if (!group) {
24✔
82
        for (auto& memlet : dfg.out_edges(access_node_)) {
11✔
83
            auto* candidate = mla.tile_group_for(loop_, memlet);
11✔
84
            if (candidate) {
11✔
85
                group = candidate;
11✔
86
                break;
11✔
87
            }
11✔
88
        }
11✔
89
    }
11✔
90
    if (!group) {
24✔
91
        return false;
×
92
    }
×
93

94
    auto& tile = group->tile;
24✔
95

96
    // Store group memlets for use in apply()
97
    group_memlets_.clear();
24✔
98
    group_memlets_.insert(group->memlets.begin(), group->memlets.end());
24✔
99

100
    // Get overapproximated extents (integer upper bounds)
101
    auto extents = tile.extents_approx();
24✔
102
    if (extents.empty()) {
24✔
103
        return false;
×
104
    }
×
105
    for (auto& ext : extents) {
41✔
106
        if (ext.is_null()) {
41✔
NEW
107
            return false;
×
NEW
108
        }
×
109
    }
41✔
110

111
    // Store tile info (before substitution, bases/strides stay symbolic)
112
    tile_info_.dimensions = extents;
24✔
113
    tile_info_.bases = tile.min_subset;
24✔
114
    tile_info_.strides = std::vector<symbolic::Expression>(tile.layout.strides().begin(), tile.layout.strides().end());
24✔
115
    tile_info_.offset = tile.layout.offset();
24✔
116

117
    // GPU shared memory: resolve symbolic extents using GPU block sizes and
118
    // require at least one cooperative dimension
119
    if (storage_type_.is_nv_shared()) {
24✔
120
        auto ancestors = ControlFlowNode::parent_chain(loop_);
6✔
121

122
        // Build substitution map: symbolic GPU map bounds → integer block sizes
123
        for (auto* node : ancestors) {
26✔
124
            if (auto* ancestor_map = dynamic_cast<structured_control_flow::Map*>(node)) {
26✔
125
                if (!gpu::is_gpu_schedule(ancestor_map->schedule_type())) {
10✔
126
                    continue;
×
127
                }
×
128
                auto block_size = gpu::gpu_block_size(ancestor_map->schedule_type());
10✔
129
                // Extract symbolic bound from condition: Lt(indvar, BOUND)
130
                auto condition = ancestor_map->condition();
10✔
131
                if (SymEngine::is_a<SymEngine::StrictLessThan>(*condition)) {
10✔
132
                    auto stl = SymEngine::rcp_static_cast<const SymEngine::StrictLessThan>(condition);
10✔
133
                    auto rhs = stl->get_args()[1];
10✔
134
                    auto iter_count = symbolic::sub(rhs, ancestor_map->init());
10✔
135
                    if (!SymEngine::is_a<SymEngine::Integer>(*iter_count)) {
10✔
136
                        // Symbolic bound — substitute with block size in extents and bases
137
                        for (auto& ext : tile_info_.dimensions) {
17✔
138
                            ext = symbolic::simplify(symbolic::subs(ext, iter_count, block_size));
17✔
139
                        }
17✔
140
                        for (auto& base : tile_info_.bases) {
17✔
141
                            base = symbolic::simplify(symbolic::subs(base, iter_count, block_size));
17✔
142
                        }
17✔
143
                    }
10✔
144
                }
10✔
145
            }
10✔
146
        }
26✔
147

148
        // Criterion: All extents must now be provably integer
149
        for (auto& ext : tile_info_.dimensions) {
10✔
150
            if (!SymEngine::is_a<SymEngine::Integer>(*ext)) {
10✔
151
                return false;
2✔
152
            }
2✔
153
        }
10✔
154

155
        // Criterion: At least one cooperative dimension
156
        bool has_cooperative_dim = false;
4✔
157
        for (auto* node : ancestors) {
12✔
158
            if (auto* ancestor_map = dynamic_cast<structured_control_flow::Map*>(node)) {
12✔
159
                if (!gpu::is_gpu_schedule(ancestor_map->schedule_type())) {
6✔
160
                    continue;
×
161
                }
×
162
                bool appears_in_bases = false;
6✔
163
                for (auto& base : tile_info_.bases) {
9✔
164
                    if (symbolic::uses(base, ancestor_map->indvar())) {
9✔
165
                        appears_in_bases = true;
2✔
166
                        break;
2✔
167
                    }
2✔
168
                }
9✔
169
                if (!appears_in_bases) {
6✔
170
                    has_cooperative_dim = true;
4✔
171
                    break;
4✔
172
                }
4✔
173
            }
6✔
174
        }
12✔
175
        if (!has_cooperative_dim) {
4✔
176
            return false;
×
177
        }
×
178
    } else {
18✔
179
        // CPU path: All extents must be provably integer
180
        for (auto& ext : tile_info_.dimensions) {
31✔
181
            if (!SymEngine::is_a<SymEngine::Integer>(*ext)) {
31✔
182
                return false;
×
183
            }
×
184
        }
31✔
185
    }
18✔
186

187
    return true;
22✔
188
}
24✔
189

190
void OutLocalStorage::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
20✔
191
    auto& sdfg = builder.subject();
20✔
192
    auto& users = analysis_manager.get<analysis::Users>();
20✔
193

194
    auto parent_node = loop_.get_parent();
20✔
195
    auto parent = dynamic_cast<structured_control_flow::Sequence*>(parent_node);
20✔
196
    if (!parent) {
20✔
197
        throw InvalidSDFGException("OutLocalStorage: Parent of loop must be a Sequence!");
×
198
    }
×
199

200
    // Get type information.
201
    auto* memlet = *group_memlets_.begin();
20✔
202
    types::Scalar scalar_type(memlet->base_type().primitive_type());
20✔
203
    types::Pointer pointer_type(scalar_type);
20✔
204

205
    // Create local buffer name
206
    local_name_ = builder.find_new_name("__daisy_out_local_storage_" + this->container_);
20✔
207

208

209
    // Collect varying dimensions (extent > 1) and compute buffer layout.
210
    // Extent-1 dimensions are degenerate (no loop is needed) and must be
211
    // skipped when sizing the buffer, when creating copy indvars, and when
212
    // linearizing into the local buffer.  The bookkeeping must match what
213
    // `build_original_subset` expects (it indexes copy_indices by varying
214
    // dimension only).
215
    std::vector<size_t> varying_dims;
20✔
216
    std::vector<symbolic::Expression> dim_sizes;
20✔
217
    for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
55✔
218
        auto& dim_size = tile_info_.dimensions.at(d);
35✔
219
        if (!symbolic::eq(dim_size, symbolic::integer(1))) {
35✔
220
            varying_dims.push_back(d);
20✔
221
            dim_sizes.push_back(dim_size);
20✔
222
        }
20✔
223
    }
35✔
224

225
    // Compute total buffer size
226
    symbolic::Expression total_size = symbolic::integer(1);
20✔
227
    for (auto& ds : dim_sizes) {
20✔
228
        total_size = symbolic::mul(total_size, ds);
20✔
229
    }
20✔
230

231
    // Create the local buffer with specified storage type
232
    types::Array buffer_type(storage_type_, 0, {}, scalar_type, total_size);
20✔
233
    builder.add_container(local_name_, buffer_type);
20✔
234

235
    // Helper: build linearized local index from per-dimension expressions
236
    auto linearize_exprs = [&](const std::vector<symbolic::Expression>& indices) -> symbolic::Expression {
65✔
237
        symbolic::Expression linear_idx = symbolic::integer(0);
65✔
238
        symbolic::Expression stride = symbolic::integer(1);
65✔
239
        for (int i = indices.size() - 1; i >= 0; i--) {
128✔
240
            linear_idx = symbolic::add(linear_idx, symbolic::mul(indices[i], stride));
63✔
241
            stride = symbolic::mul(stride, dim_sizes[i]);
63✔
242
        }
63✔
243
        return linear_idx;
65✔
244
    };
65✔
245

246
    // Helper: build linearized local index from per-dimension indvars (symbols)
247
    auto linearize = [&](const std::vector<symbolic::Symbol>& indvars) -> symbolic::Expression {
29✔
248
        std::vector<symbolic::Expression> exprs(indvars.begin(), indvars.end());
29✔
249
        return linearize_exprs(exprs);
29✔
250
    };
29✔
251

252
    // Helper: build source subset (base[d] + copy_indvar[d]) for original container
253
    auto build_original_subset = [&](const std::vector<symbolic::Expression>& copy_indices) -> data_flow::Subset {
34✔
254
        std::vector<symbolic::Expression> full_indices;
34✔
255
        size_t var_idx = 0;
34✔
256
        for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
95✔
257
            if (!symbolic::eq(tile_info_.dimensions.at(d), symbolic::integer(1))) {
61✔
258
                full_indices.push_back(symbolic::add(tile_info_.bases.at(d), copy_indices.at(var_idx++)));
34✔
259
            } else {
34✔
260
                full_indices.push_back(tile_info_.bases.at(d));
27✔
261
            }
27✔
262
        }
61✔
263

264
        symbolic::Expression linear = tile_info_.offset;
34✔
265
        for (size_t d = 0; d < full_indices.size(); d++) {
95✔
266
            linear = symbolic::add(linear, symbolic::mul(tile_info_.strides.at(d), full_indices.at(d)));
61✔
267
        }
61✔
268
        return {linear};
34✔
269
    };
34✔
270

271
    if (storage_type_.is_nv_shared()) {
20✔
272
        // ============================================================
273
        // GPU COOPERATIVE PATH
274
        // ============================================================
275
        auto ancestors = ControlFlowNode::parent_chain(loop_);
4✔
276

277
        // Collect cooperative GPU dimensions
278
        struct CoopDim {
4✔
279
            symbolic::Symbol indvar;
4✔
280
            symbolic::Integer block_size;
4✔
281
            gpu::GPUDimension dimension;
4✔
282
        };
4✔
283
        std::vector<CoopDim> coop_dims;
4✔
284

285
        for (auto* node : ancestors) {
20✔
286
            if (auto* ancestor_map = dynamic_cast<structured_control_flow::Map*>(node)) {
20✔
287
                if (!gpu::is_gpu_schedule(ancestor_map->schedule_type())) {
8✔
NEW
288
                    continue;
×
NEW
289
                }
×
290
                bool appears_in_bases = false;
8✔
291
                for (auto& base : tile_info_.bases) {
11✔
292
                    if (symbolic::uses(base, ancestor_map->indvar())) {
11✔
293
                        appears_in_bases = true;
3✔
294
                        break;
3✔
295
                    }
3✔
296
                }
11✔
297
                if (!appears_in_bases) {
8✔
298
                    coop_dims.push_back(
5✔
299
                        {ancestor_map->indvar(),
5✔
300
                         gpu::gpu_block_size(ancestor_map->schedule_type()),
5✔
301
                         gpu::gpu_dimension(ancestor_map->schedule_type())}
5✔
302
                    );
5✔
303
                }
5✔
304
            }
8✔
305
        }
20✔
306

307
        // Compute total cooperative thread count
308
        symbolic::Expression total_coop_threads = symbolic::integer(1);
4✔
309
        for (auto& cd : coop_dims) {
5✔
310
            total_coop_threads = symbolic::mul(total_coop_threads, cd.block_size);
5✔
311
        }
5✔
312

313
        // Flatten cooperative thread index
314
        symbolic::Expression coop_flat = symbolic::integer(0);
4✔
315
        symbolic::Expression coop_stride = symbolic::integer(1);
4✔
316
        for (int i = coop_dims.size() - 1; i >= 0; i--) {
9✔
317
            coop_flat = symbolic::add(coop_flat, symbolic::mul(coop_dims[i].indvar, coop_stride));
5✔
318
            coop_stride = symbolic::mul(coop_stride, coop_dims[i].block_size);
5✔
319
        }
5✔
320

321
        // INIT: barrier → cooperative copy-in → barrier (if has_read)
322
        if (tile_info_.has_read) {
4✔
323
            // Barrier before init
324
            auto& barrier_block1 = builder.add_block_before(*parent, loop_, {}, loop_.debug_info());
1✔
325
            builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block1, {});
1✔
326

327
            // Cooperative copy-in loop
328
            auto idx_name = builder.find_new_name("__daisy_ols_coop_init_" + this->container_);
1✔
329
            types::Scalar idx_type(types::PrimitiveType::UInt64);
1✔
330
            builder.add_container(idx_name, idx_type);
1✔
331
            auto idx_var = symbolic::symbol(idx_name);
1✔
332

333
            auto& init_loop = builder.add_map_before(
1✔
334
                *parent,
1✔
335
                loop_,
1✔
336
                idx_var,
1✔
337
                symbolic::Lt(idx_var, total_size),
1✔
338
                coop_flat,
1✔
339
                symbolic::add(idx_var, total_coop_threads),
1✔
340
                structured_control_flow::ScheduleType_Sequential::create(),
1✔
341
                {},
1✔
342
                loop_.debug_info()
1✔
343
            );
1✔
344

345
            auto& init_block = builder.add_block(init_loop.root());
1✔
346
            auto& init_src = builder.add_access(init_block, this->container_);
1✔
347
            auto& init_dst = builder.add_access(init_block, local_name_);
1✔
348
            auto& init_tasklet = builder.add_tasklet(init_block, data_flow::TaskletCode::assign, "_out", {"_in"});
1✔
349

350
            // Decompose idx_var into per-dim indices over varying dims only
351
            std::vector<symbolic::Expression> init_indices;
1✔
352
            symbolic::Expression remainder = idx_var;
1✔
353
            for (size_t i = 0; i < dim_sizes.size(); i++) {
2✔
354
                if (i < dim_sizes.size() - 1) {
1✔
NEW
355
                    symbolic::Expression divisor = symbolic::integer(1);
×
NEW
356
                    for (size_t j = i + 1; j < dim_sizes.size(); j++) {
×
NEW
357
                        divisor = symbolic::mul(divisor, dim_sizes[j]);
×
UNCOV
358
                    }
×
NEW
359
                    init_indices.push_back(symbolic::div(remainder, divisor));
×
NEW
360
                    remainder = symbolic::mod(remainder, divisor);
×
361
                } else {
1✔
362
                    init_indices.push_back(remainder);
1✔
363
                }
1✔
364
            }
1✔
365

366
            auto init_src_subset = build_original_subset(init_indices);
1✔
367
            builder.add_computational_memlet(init_block, init_src, init_tasklet, "_in", init_src_subset, pointer_type);
1✔
368
            builder.add_computational_memlet(init_block, init_tasklet, "_out", init_dst, {idx_var}, buffer_type);
1✔
369

370
            // Barrier after init
371
            auto& barrier_block2 = builder.add_block_before(*parent, loop_, {}, loop_.debug_info());
1✔
372
            builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block2, {});
1✔
373
        }
1✔
374

375
        // WRITEBACK: barrier → cooperative copy-out → barrier
376
        {
4✔
377
            // Barrier before writeback
378
            auto& barrier_block3 = builder.add_block_after(*parent, loop_, {}, loop_.debug_info());
4✔
379
            builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block3, {});
4✔
380

381
            // Cooperative writeback loop
382
            auto idx_name = builder.find_new_name("__daisy_ols_coop_wb_" + this->container_);
4✔
383
            types::Scalar idx_type(types::PrimitiveType::UInt64);
4✔
384
            builder.add_container(idx_name, idx_type);
4✔
385
            auto idx_var = symbolic::symbol(idx_name);
4✔
386

387
            auto& wb_loop = builder.add_map_after(
4✔
388
                *parent,
4✔
389
                loop_,
4✔
390
                idx_var,
4✔
391
                symbolic::Lt(idx_var, total_size),
4✔
392
                coop_flat,
4✔
393
                symbolic::add(idx_var, total_coop_threads),
4✔
394
                structured_control_flow::ScheduleType_Sequential::create(),
4✔
395
                {},
4✔
396
                loop_.debug_info()
4✔
397
            );
4✔
398

399
            auto& wb_block = builder.add_block(wb_loop.root());
4✔
400
            auto& wb_src = builder.add_access(wb_block, local_name_);
4✔
401
            auto& wb_dst = builder.add_access(wb_block, this->container_);
4✔
402
            auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
4✔
403

404
            // Decompose idx_var into per-dim indices over varying dims only
405
            std::vector<symbolic::Expression> wb_indices;
4✔
406
            symbolic::Expression remainder = idx_var;
4✔
407
            for (size_t i = 0; i < dim_sizes.size(); i++) {
8✔
408
                if (i < dim_sizes.size() - 1) {
4✔
NEW
409
                    symbolic::Expression divisor = symbolic::integer(1);
×
NEW
410
                    for (size_t j = i + 1; j < dim_sizes.size(); j++) {
×
NEW
411
                        divisor = symbolic::mul(divisor, dim_sizes[j]);
×
NEW
412
                    }
×
NEW
413
                    wb_indices.push_back(symbolic::div(remainder, divisor));
×
NEW
414
                    remainder = symbolic::mod(remainder, divisor);
×
415
                } else {
4✔
416
                    wb_indices.push_back(remainder);
4✔
417
                }
4✔
418
            }
4✔
419

420
            auto wb_dst_subset = build_original_subset(wb_indices);
4✔
421
            builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", {idx_var}, buffer_type);
4✔
422
            builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, wb_dst_subset, pointer_type);
4✔
423

424
            // Barrier after writeback
425
            auto& barrier_block4 = builder.add_block_after(*parent, loop_, {}, loop_.debug_info());
4✔
426
            builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block4, {});
4✔
427
        }
4✔
428
    } else {
16✔
429
        // ============================================================
430
        // CPU SEQUENTIAL PATH
431
        // ============================================================
432
        if (tile_info_.has_read) {
16✔
433
            std::vector<symbolic::Symbol> init_indvars;
13✔
434
            structured_control_flow::Sequence* init_scope =
13✔
435
                &builder.add_sequence_before(*parent, loop_, {}, loop_.debug_info());
13✔
436
            for (size_t i = 0; i < varying_dims.size(); i++) {
26✔
437
                size_t d = varying_dims[i];
13✔
438
                auto indvar_name =
13✔
439
                    builder.find_new_name("__daisy_ols_init_" + this->container_ + "_d" + std::to_string(d));
13✔
440
                types::Scalar indvar_type(types::PrimitiveType::UInt64);
13✔
441
                builder.add_container(indvar_name, indvar_type);
13✔
442
                auto indvar = symbolic::symbol(indvar_name);
13✔
443
                init_indvars.push_back(indvar);
13✔
444

445
                auto init = symbolic::integer(0);
13✔
446
                auto condition = symbolic::Lt(indvar, dim_sizes[i]);
13✔
447
                auto update = symbolic::add(indvar, symbolic::integer(1));
13✔
448

449
                auto& init_loop = builder.add_map(
13✔
450
                    *init_scope,
13✔
451
                    indvar,
13✔
452
                    condition,
13✔
453
                    init,
13✔
454
                    update,
13✔
455
                    structured_control_flow::ScheduleType_Sequential::create(),
13✔
456
                    {},
13✔
457
                    loop_.debug_info()
13✔
458
                );
13✔
459
                init_scope = &init_loop.root();
13✔
460
            }
13✔
461

462
            // Create init copy block
463
            auto& init_block = builder.add_block(*init_scope);
13✔
464
            auto& init_src = builder.add_access(init_block, this->container_);
13✔
465
            auto& init_dst = builder.add_access(init_block, local_name_);
13✔
466
            auto& init_tasklet = builder.add_tasklet(init_block, data_flow::TaskletCode::assign, "_out", {"_in"});
13✔
467

468
            std::vector<symbolic::Expression> init_exprs(init_indvars.begin(), init_indvars.end());
13✔
469
            auto init_src_subset = build_original_subset(init_exprs);
13✔
470
            data_flow::Subset init_dst_subset = {linearize(init_indvars)};
13✔
471

472
            builder.add_computational_memlet(init_block, init_src, init_tasklet, "_in", init_src_subset, pointer_type);
13✔
473
            builder.add_computational_memlet(init_block, init_tasklet, "_out", init_dst, init_dst_subset, buffer_type);
13✔
474
        }
13✔
475

476
        // Writeback Maps
477
        {
16✔
478
            std::vector<symbolic::Symbol> wb_indvars;
16✔
479
            structured_control_flow::Sequence* wb_scope =
16✔
480
                &builder.add_sequence_after(*parent, loop_, {}, loop_.debug_info());
16✔
481
            for (size_t i = 0; i < varying_dims.size(); i++) {
32✔
482
                size_t d = varying_dims[i];
16✔
483
                auto indvar_name =
16✔
484
                    builder.find_new_name("__daisy_ols_wb_" + this->container_ + "_d" + std::to_string(d));
16✔
485
                types::Scalar indvar_type(types::PrimitiveType::UInt64);
16✔
486
                builder.add_container(indvar_name, indvar_type);
16✔
487
                auto indvar = symbolic::symbol(indvar_name);
16✔
488
                wb_indvars.push_back(indvar);
16✔
489

490
                auto init = symbolic::integer(0);
16✔
491
                auto condition = symbolic::Lt(indvar, dim_sizes[i]);
16✔
492
                auto update = symbolic::add(indvar, symbolic::integer(1));
16✔
493

494
                auto& wb_loop = builder.add_map(
16✔
495
                    *wb_scope,
16✔
496
                    indvar,
16✔
497
                    condition,
16✔
498
                    init,
16✔
499
                    update,
16✔
500
                    structured_control_flow::ScheduleType_Sequential::create(),
16✔
501
                    {},
16✔
502
                    loop_.debug_info()
16✔
503
                );
16✔
504
                wb_scope = &wb_loop.root();
16✔
505
            }
16✔
506

507
            // Create writeback copy block
508
            auto& wb_block = builder.add_block(*wb_scope);
16✔
509
            auto& wb_src = builder.add_access(wb_block, local_name_);
16✔
510
            auto& wb_dst = builder.add_access(wb_block, this->container_);
16✔
511
            auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
16✔
512

513
            std::vector<symbolic::Expression> wb_exprs(wb_indvars.begin(), wb_indvars.end());
16✔
514
            data_flow::Subset wb_src_subset = {linearize(wb_indvars)};
16✔
515
            auto wb_dst_subset = build_original_subset(wb_exprs);
16✔
516

517
            builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", wb_src_subset, buffer_type);
16✔
518
            builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, wb_dst_subset, pointer_type);
16✔
519
        }
16✔
520
    }
16✔
521

522
    // ==================================================================
523
    // Update accesses in the main loop to use the local buffer
524
    // ==================================================================
525
    auto& mla = analysis_manager.get<analysis::MemoryLayoutAnalysis>();
20✔
526

527
    // Recursive helper to traverse all blocks in the loop body
528
    std::function<void(structured_control_flow::ControlFlowNode&)> rewrite_accesses;
20✔
529
    rewrite_accesses = [&](structured_control_flow::ControlFlowNode& node) {
63✔
530
        if (auto* block = dynamic_cast<structured_control_flow::Block*>(&node)) {
63✔
531
            auto& dfg = block->dataflow();
23✔
532

533
            // Collect access nodes to process (avoid iterator invalidation when splitting)
534
            std::vector<data_flow::AccessNode*> access_nodes;
23✔
535
            for (auto* access_node : dfg.data_nodes()) {
65✔
536
                if (access_node->data() == this->container_) {
65✔
537
                    access_nodes.push_back(access_node);
36✔
538
                }
36✔
539
            }
65✔
540

541
            for (auto* access : access_nodes) {
36✔
542
                // Classify memlets: group vs non-group
543
                struct MemletRewrite {
36✔
544
                    data_flow::Memlet* memlet;
36✔
545
                    data_flow::Subset local_subset;
36✔
546
                    bool is_outgoing;
36✔
547
                };
36✔
548
                std::vector<MemletRewrite> group_rewrites;
36✔
549
                bool all_in_group = true;
36✔
550

551
                // Outgoing memlets (reads from this access node)
552
                for (auto& memlet : dfg.out_edges(*access)) {
36✔
553
                    if (group_memlets_.count(&memlet) == 0) {
15✔
NEW
554
                        all_in_group = false;
×
NEW
555
                        continue;
×
UNCOV
556
                    }
×
557
                    auto* acc = mla.access(memlet);
15✔
558
                    if (acc && acc->subset.size() == tile_info_.dimensions.size()) {
15✔
559
                        std::vector<symbolic::Expression> local_indices;
15✔
560
                        for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
43✔
561
                            if (!symbolic::eq(tile_info_.dimensions.at(d), symbolic::integer(1))) {
28✔
562
                                local_indices.push_back(symbolic::sub(acc->subset.at(d), tile_info_.bases.at(d)));
14✔
563
                            }
14✔
564
                        }
28✔
565
                        symbolic::Expression linear_idx = linearize_exprs(local_indices);
15✔
566
                        group_rewrites.push_back({&memlet, {linear_idx}, true});
15✔
567
                    } else {
15✔
568
                        // Memlet is claimed by the group but we cannot rewrite it (no
569
                        // delinearized access info). Leaving it as the original container
570
                        // would create a half-renamed access node. Bail out of renaming
571
                        // to keep the SDFG consistent.
NEW
572
                        all_in_group = false;
×
UNCOV
573
                    }
×
574
                }
15✔
575
                // Incoming memlets (writes to this access node)
576
                for (auto& memlet : dfg.in_edges(*access)) {
36✔
577
                    if (group_memlets_.count(&memlet) == 0) {
21✔
NEW
578
                        all_in_group = false;
×
NEW
579
                        continue;
×
NEW
580
                    }
×
581
                    auto* acc = mla.access(memlet);
21✔
582
                    if (acc && acc->subset.size() == tile_info_.dimensions.size()) {
21✔
583
                        std::vector<symbolic::Expression> local_indices;
21✔
584
                        for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
58✔
585
                            if (!symbolic::eq(tile_info_.dimensions.at(d), symbolic::integer(1))) {
37✔
586
                                local_indices.push_back(symbolic::sub(acc->subset.at(d), tile_info_.bases.at(d)));
20✔
587
                            }
20✔
588
                        }
37✔
589
                        symbolic::Expression linear_idx = linearize_exprs(local_indices);
21✔
590
                        group_rewrites.push_back({&memlet, {linear_idx}, false});
21✔
591
                    } else {
21✔
NEW
592
                        all_in_group = false;
×
UNCOV
593
                    }
×
594
                }
21✔
595

596
                if (group_rewrites.empty()) continue;
36✔
597

598
                if (all_in_group) {
36✔
599
                    // Simple case: all memlets in group → rewrite in-place and rename
600
                    for (auto& rw : group_rewrites) {
36✔
601
                        rw.memlet->set_subset(rw.local_subset);
36✔
602
                        rw.memlet->set_base_type(buffer_type);
36✔
603
                    }
36✔
604
                    access->data(local_name_);
36✔
605
                } else {
36✔
606
                    // Mixed case: split — create new local access node, redirect group memlets
NEW
607
                    auto& local_access = builder.add_access(*block, local_name_);
×
NEW
608
                    for (auto& rw : group_rewrites) {
×
NEW
609
                        if (rw.is_outgoing) {
×
610
                            // outgoing: access→tasklet  →  local_access→tasklet
NEW
611
                            auto& dst_node = rw.memlet->dst();
×
NEW
612
                            auto dst_conn = rw.memlet->dst_conn();
×
NEW
613
                            builder.remove_memlet(*block, *rw.memlet);
×
NEW
614
                            builder.add_memlet(
×
NEW
615
                                *block, local_access, "void", dst_node, dst_conn, rw.local_subset, buffer_type, {}
×
NEW
616
                            );
×
NEW
617
                        } else {
×
618
                            // incoming: tasklet→access  →  tasklet→local_access
NEW
619
                            auto& src_node = rw.memlet->src();
×
NEW
620
                            auto src_conn = rw.memlet->src_conn();
×
NEW
621
                            builder.remove_memlet(*block, *rw.memlet);
×
NEW
622
                            builder.add_memlet(
×
NEW
623
                                *block, src_node, src_conn, local_access, "void", rw.local_subset, buffer_type, {}
×
NEW
624
                            );
×
625
                        }
×
626
                    }
×
UNCOV
627
                }
×
628
            }
36✔
629
        } else if (auto* seq = dynamic_cast<structured_control_flow::Sequence*>(&node)) {
40✔
630
            for (size_t i = 0; i < seq->size(); i++) {
63✔
631
                rewrite_accesses(seq->at(i).first);
33✔
632
            }
33✔
633
        } else if (auto* loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&node)) {
30✔
634
            rewrite_accesses(loop->root());
10✔
635
        } else if (auto* if_else = dynamic_cast<structured_control_flow::IfElse*>(&node)) {
10✔
NEW
636
            for (size_t i = 0; i < if_else->size(); i++) {
×
NEW
637
                rewrite_accesses(if_else->at(i).first);
×
NEW
638
            }
×
NEW
639
        }
×
640
    };
63✔
641
    rewrite_accesses(loop_.root());
20✔
642

643
    // Cleanup
644
    analysis_manager.invalidate_all();
20✔
645

646
    passes::SequenceFusion sf_pass;
20✔
647
    passes::DeadCFGElimination dce_pass;
20✔
648
    bool applies = false;
20✔
649
    do {
36✔
650
        applies = false;
36✔
651
        applies |= dce_pass.run(builder, analysis_manager);
36✔
652
        applies |= sf_pass.run(builder, analysis_manager);
36✔
653
    } while (applies);
36✔
654
};
20✔
655

656
// Serializes this transformation into `j`.
//
// Writes two keys:
//   - "subgraph": an object with entry "0" describing the loop (element id plus
//     "for" or "map") and entry "1" describing the access node (element id plus
//     "access_node"). Any previous "subgraph" value is replaced wholesale.
//   - "transformation_type": the value returned by name().
//
// Throws std::runtime_error (before touching `j`) when the loop is neither a
// For nor a Map, since only those two loop kinds have a serialized tag.
void OutLocalStorage::to_json(nlohmann::json& j) const {
    const char* kind = nullptr;
    if (dynamic_cast<structured_control_flow::For*>(&loop_) != nullptr) {
        kind = "for";
    } else if (dynamic_cast<structured_control_flow::Map*>(&loop_) != nullptr) {
        kind = "map";
    }

    if (kind == nullptr) {
        throw std::runtime_error("Unsupported loop type for serialization of loop: " + loop_.indvar()->get_name());
    }

    nlohmann::json subgraph = nlohmann::json::object();
    subgraph["0"] = {{"element_id", loop_.element_id()}, {"type", kind}};
    subgraph["1"] = {{"element_id", access_node_.element_id()}, {"type", "access_node"}};

    j["subgraph"] = std::move(subgraph);
    j["transformation_type"] = this->name();
};
3✔
671

672
// Reconstructs an OutLocalStorage transformation from a JSON description
// produced by to_json.
//
// Expects desc["subgraph"]["0"]["element_id"] to name a structured loop and
// desc["subgraph"]["1"]["element_id"] to name an access node, both resolved
// through `builder`.
//
// Throws:
//   - nlohmann::json exceptions (e.g. out_of_range / type_error) when the
//     expected keys are missing or have the wrong type; `.at()` is used
//     throughout so a malformed description never hits const operator[] UB.
//   - InvalidTransformationDescriptionException when an ID does not resolve,
//     or resolves to an element of the wrong kind.
//
// NOTE(review): to_json does not serialize storage_type_, so the returned
// transformation uses whatever storage type the constructor defaults to
// (presumably declared in the header) — confirm this is intended.
OutLocalStorage OutLocalStorage::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& desc) {
    const auto& subgraph = desc.at("subgraph");

    // Resolve the loop. Guard the downcast: an ID that names some other element
    // kind would otherwise be dereferenced as a null loop pointer below.
    const auto loop_id = subgraph.at("0").at("element_id").get<size_t>();
    auto element = builder.find_element_by_id(loop_id);
    if (!element) {
        throw InvalidTransformationDescriptionException("Element with ID " + std::to_string(loop_id) + " not found.");
    }
    auto* loop = dynamic_cast<structured_control_flow::StructuredLoop*>(element);
    if (!loop) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(loop_id) + " is not a structured loop."
        );
    }

    // Resolve the access node; parse its ID once and reuse it for the message.
    const auto access_node_id = subgraph.at("1").at("element_id").get<size_t>();
    auto* access_node = dynamic_cast<data_flow::AccessNode*>(builder.find_element_by_id(access_node_id));
    if (!access_node) {
        throw InvalidTransformationDescriptionException(
            "Access node with ID " + std::to_string(access_node_id) + " not found."
        );
    }

    return OutLocalStorage(*loop, *access_node);
};
1✔
691

692
} // namespace transformations
693
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc