• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27449564918

12 Jun 2026 11:46PM UTC coverage: 61.331% (-0.02%) from 61.354%
27449564918

push

github

web-flow
adds support for polybench-style pointers in local storage (#757)

* adds support for polybench-style pointers in local storage

* local storage handles containers with opaque pointers

* adds more corner case handling for local storage

* disables flaky softmax

79 of 109 new or added lines in 3 files covered. (72.48%)

11 existing lines in 1 file now uncovered.

36336 of 59246 relevant lines covered (61.33%)

1121.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.65
/opt/src/transformations/out_local_storage.cpp
1
#include "sdfg/transformations/out_local_storage.h"
2

3
#include <cstddef>
4
#include <functional>
5
#include <string>
6

7
#include "sdfg/analysis/memory_layout_analysis.h"
8
#include "sdfg/analysis/users.h"
9
#include "sdfg/builder/structured_sdfg_builder.h"
10
#include "sdfg/data_flow/access_node.h"
11
#include "sdfg/data_flow/library_nodes/barrier_local_node.h"
12
#include "sdfg/data_flow/memlet.h"
13
#include "sdfg/passes/structured_control_flow/dead_cfg_elimination.h"
14
#include "sdfg/passes/structured_control_flow/sequence_fusion.h"
15
#include "sdfg/structured_control_flow/if_else.h"
16
#include "sdfg/structured_control_flow/sequence.h"
17
#include "sdfg/structured_control_flow/structured_loop.h"
18
#include "sdfg/symbolic/symbolic.h"
19
#include "sdfg/targets/gpu/gpu_schedule_type.h"
20
#include "sdfg/types/array.h"
21
#include "sdfg/types/pointer.h"
22
#include "sdfg/types/scalar.h"
23

24
namespace sdfg {
25
namespace transformations {
26

27
/**
 * Creates an OutLocalStorage transformation.
 *
 * @param loop          Loop around which the written container is staged in a local buffer.
 * @param access_node   Access node identifying the container; its `data()` is stored as the
 *                      container name.
 * @param storage_type  Storage type used for the local array buffer created by apply().
 */
OutLocalStorage::OutLocalStorage(
    structured_control_flow::StructuredLoop& loop,
    const data_flow::AccessNode& access_node,
    const types::StorageType& storage_type
)
    : loop_(loop), access_node_(access_node), container_(access_node.data()), storage_type_(storage_type) {}
30✔
33

34
/// @brief Returns the transformation's identifier, "OutLocalStorage".
std::string OutLocalStorage::name() const { return "OutLocalStorage"; }
5✔
35

36
bool OutLocalStorage::can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
28✔
37
    auto& sdfg = builder.subject();
28✔
38
    auto& body = this->loop_.root();
28✔
39

40
    tile_info_ = TileInfo{};
28✔
41

42
    // Criterion: Container must exist
43
    if (!sdfg.exists(this->container_)) {
28✔
44
        return false;
×
45
    }
×
46

47
    auto& type = sdfg.type(this->container_);
28✔
48

49
    // Criterion: Container must be used in the loop body
50
    auto& users = analysis_manager.get<analysis::Users>();
28✔
51
    analysis::UsersView body_users(users, body);
28✔
52
    if (body_users.uses(this->container_).empty()) {
28✔
53
        return false;
2✔
54
    }
2✔
55

56
    // Criterion: Container must have writes (this is OutLocalStorage, not InLocalStorage)
57
    if (body_users.writes(this->container_).empty()) {
26✔
58
        return false;
1✔
59
    }
1✔
60

61
    // Determine if container is also read (read-write vs write-only)
62
    tile_info_.has_read = !body_users.reads(this->container_).empty();
25✔
63

64
    // Handle scalar containers: no tile needed, dimensions stay empty
65
    if (type.type_id() == types::TypeID::Scalar) {
25✔
66
        return true;
1✔
67
    }
1✔
68

69
    // For Array/Pointer types: use MemoryLayoutAnalysis tile group API
70
    if (type.type_id() != types::TypeID::Pointer) {
24✔
71
        return false;
×
72
    }
×
73

74
    auto& mla = analysis_manager.get<analysis::MemoryLayoutAnalysis>();
24✔
75

76
    // Find a representative memlet from the access node to identify its group.
77
    // An access node may have multiple edges belonging to different tile groups.
78
    // We iterate all edges and select the first one whose tile group is valid
79
    // at the target loop level.
80
    const analysis::MemoryTileGroup* group = nullptr;
24✔
81
    auto& dfg = access_node_.get_parent();
24✔
82
    for (auto& memlet : dfg.in_edges(access_node_)) {
24✔
83
        auto* candidate = mla.tile_group_for(loop_, memlet);
11✔
84
        if (candidate) {
11✔
85
            group = candidate;
11✔
86
            break;
11✔
87
        }
11✔
88
    }
11✔
89
    if (!group) {
24✔
90
        for (auto& memlet : dfg.out_edges(access_node_)) {
13✔
91
            auto* candidate = mla.tile_group_for(loop_, memlet);
13✔
92
            if (candidate) {
13✔
93
                group = candidate;
13✔
94
                break;
13✔
95
            }
13✔
96
        }
13✔
97
    }
13✔
98
    if (!group) {
24✔
99
        return false;
×
100
    }
×
101

102
    auto& tile = group->tile;
24✔
103

104
    // Store group memlets for use in apply()
105
    group_memlets_.clear();
24✔
106
    group_memlets_.insert(group->memlets.begin(), group->memlets.end());
24✔
107

108
    // Get overapproximated extents (integer upper bounds)
109
    auto extents = tile.extents_approx();
24✔
110
    if (extents.empty()) {
24✔
111
        return false;
×
112
    }
×
113
    // Reject if any extent depends on an unbounded leading dimension (returned as null
114
    // by extents_approx). Downstream code (substitution, stride computation) would
115
    // dereference these.
116
    for (auto& ext : extents) {
41✔
117
        if (ext.is_null()) return false;
41✔
118
    }
41✔
119

120
    // Store tile info (before substitution, bases/strides stay symbolic)
121
    tile_info_.dimensions = extents;
24✔
122
    tile_info_.bases = tile.min_subset;
24✔
123
    tile_info_.strides = std::vector<symbolic::Expression>(tile.layout.strides().begin(), tile.layout.strides().end());
24✔
124
    tile_info_.offset = tile.layout.offset();
24✔
125

126
    // GPU shared memory: resolve symbolic extents using GPU block sizes and
127
    // require at least one cooperative dimension
128
    if (storage_type_.is_nv_shared()) {
24✔
129
        auto ancestors = ControlFlowNode::parent_chain(loop_);
6✔
130

131
        // Build substitution map: symbolic GPU map bounds → integer block sizes
132
        for (auto* node : ancestors) {
26✔
133
            if (auto* ancestor_map = dynamic_cast<structured_control_flow::Map*>(node)) {
26✔
134
                if (!gpu::is_gpu_schedule(ancestor_map->schedule_type())) {
10✔
135
                    continue;
×
136
                }
×
137
                auto block_size = gpu::gpu_block_size(ancestor_map->schedule_type());
10✔
138
                // Extract symbolic bound from condition: Lt(indvar, BOUND)
139
                auto condition = ancestor_map->condition();
10✔
140
                if (SymEngine::is_a<SymEngine::StrictLessThan>(*condition)) {
10✔
141
                    auto stl = SymEngine::rcp_static_cast<const SymEngine::StrictLessThan>(condition);
10✔
142
                    auto rhs = stl->get_args()[1];
10✔
143
                    auto iter_count = symbolic::sub(rhs, ancestor_map->init());
10✔
144
                    if (!SymEngine::is_a<SymEngine::Integer>(*iter_count)) {
10✔
145
                        // Symbolic bound — substitute with block size in extents and bases
146
                        for (auto& ext : tile_info_.dimensions) {
17✔
147
                            ext = symbolic::simplify(symbolic::subs(ext, iter_count, block_size));
17✔
148
                        }
17✔
149
                        for (auto& base : tile_info_.bases) {
17✔
150
                            base = symbolic::simplify(symbolic::subs(base, iter_count, block_size));
17✔
151
                        }
17✔
152
                    }
10✔
153
                }
10✔
154
            }
10✔
155
        }
26✔
156

157
        // Criterion: All extents must now be provably integer
158
        for (auto& ext : tile_info_.dimensions) {
10✔
159
            if (!SymEngine::is_a<SymEngine::Integer>(*ext)) {
10✔
160
                return false;
2✔
161
            }
2✔
162
        }
10✔
163

164
        // Criterion: At least one cooperative dimension
165
        bool has_cooperative_dim = false;
4✔
166
        for (auto* node : ancestors) {
12✔
167
            if (auto* ancestor_map = dynamic_cast<structured_control_flow::Map*>(node)) {
12✔
168
                if (!gpu::is_gpu_schedule(ancestor_map->schedule_type())) {
6✔
169
                    continue;
×
170
                }
×
171
                bool appears_in_bases = false;
6✔
172
                for (auto& base : tile_info_.bases) {
9✔
173
                    if (symbolic::uses(base, ancestor_map->indvar())) {
9✔
174
                        appears_in_bases = true;
2✔
175
                        break;
2✔
176
                    }
2✔
177
                }
9✔
178
                if (!appears_in_bases) {
6✔
179
                    has_cooperative_dim = true;
4✔
180
                    break;
4✔
181
                }
4✔
182
            }
6✔
183
        }
12✔
184
        if (!has_cooperative_dim) {
4✔
185
            return false;
×
186
        }
×
187
    } else {
18✔
188
        // CPU path: All extents must be provably integer
189
        for (auto& ext : tile_info_.dimensions) {
31✔
190
            if (!SymEngine::is_a<SymEngine::Integer>(*ext)) {
31✔
191
                return false;
×
192
            }
×
193
        }
31✔
194
    }
18✔
195

196
    return true;
22✔
197
}
24✔
198

199
void OutLocalStorage::apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
21✔
200
    auto& sdfg = builder.subject();
21✔
201
    auto& users = analysis_manager.get<analysis::Users>();
21✔
202

203
    auto parent_node = loop_.get_parent();
21✔
204
    auto parent = dynamic_cast<structured_control_flow::Sequence*>(parent_node);
21✔
205
    if (!parent) {
21✔
206
        throw InvalidSDFGException("OutLocalStorage: Parent of loop must be a Sequence!");
×
207
    }
×
208

209
    // Get type information.
210
    //
211
    // For the array path we derive the element type from a representative memlet
212
    // (`group_memlets_` is populated by `can_be_applied`).  This handles the case
213
    // where the container has an *opaque* pointer type (`Pointer()` with no
214
    // pointee) but the memlets carry a more specific base type.
215
    //
216
    // For the scalar path `can_be_applied` returns early without populating
217
    // `group_memlets_`, so we must fall back to the container's declared type —
218
    // which by construction is `Scalar` on this path.
219
    types::Scalar scalar_type = [&]() {
21✔
220
        if (tile_info_.dimensions.empty()) {
21✔
221
            auto& type = sdfg.type(this->container_);
1✔
222
            return types::Scalar(type.primitive_type());
1✔
223
        }
1✔
224
        auto* memlet = *group_memlets_.begin();
20✔
225
        return types::Scalar(memlet->base_type().primitive_type());
20✔
226
    }();
21✔
227
    types::Pointer pointer_type(scalar_type);
21✔
228

229
    // Create local buffer name
230
    local_name_ = builder.find_new_name("__daisy_out_local_storage_" + this->container_);
21✔
231

232
    // ========================================================================
233
    // SCALAR PATH: tile_info_.dimensions is empty
234
    // ========================================================================
235
    if (tile_info_.dimensions.empty()) {
21✔
236
        // Create scalar local buffer
237
        builder.add_container(local_name_, scalar_type);
1✔
238

239
        // Get the access subset from the first user (all scalar, so empty subset)
240
        analysis::UsersView body_users(users, loop_.root());
1✔
241
        auto accesses = body_users.uses(this->container_);
1✔
242
        auto first_access = accesses.at(0);
1✔
243
        auto first_subset = first_access->subsets().at(0);
1✔
244

245
        // The scalar copy-in/copy-out memlets reference the *original* container,
246
        // so they keep its declared base type.  `sdfg.type()` is the canonical
247
        // source — the per-memlet base_type may not exist on the scalar path.
248
        auto& container_type = sdfg.type(this->container_);
1✔
249

250
        // Init block (copy from container to local) - before loop
251
        if (tile_info_.has_read) {
1✔
252
            auto& init_block = builder.add_block_before(*parent, loop_, {}, loop_.debug_info());
1✔
253
            auto& init_src = builder.add_access(init_block, this->container_);
1✔
254
            auto& init_dst = builder.add_access(init_block, local_name_);
1✔
255
            auto& init_tasklet = builder.add_tasklet(init_block, data_flow::TaskletCode::assign, "_out", {"_in"});
1✔
256
            builder.add_computational_memlet(init_block, init_src, init_tasklet, "_in", first_subset, container_type);
1✔
257
            builder.add_computational_memlet(init_block, init_tasklet, "_out", init_dst, {}, scalar_type);
1✔
258
        }
1✔
259

260
        // Writeback block (copy from local to container) - after loop
261
        {
1✔
262
            auto& wb_block = builder.add_block_after(*parent, loop_, {}, loop_.debug_info());
1✔
263
            auto& wb_src = builder.add_access(wb_block, local_name_);
1✔
264
            auto& wb_dst = builder.add_access(wb_block, this->container_);
1✔
265
            auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
1✔
266
            builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", {}, scalar_type);
1✔
267
            builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, first_subset, container_type);
1✔
268
        }
1✔
269

270
        // Rewrite body accesses to use scalar local
271
        for (auto* user : body_users.uses(this->container_)) {
2✔
272
            auto element = user->element();
2✔
273
            if (auto access = dynamic_cast<data_flow::AccessNode*>(element)) {
2✔
274
                for (auto& iedge : access->get_parent().in_edges(*access)) {
2✔
275
                    auto memlet = &iedge;
1✔
276
                    memlet->set_subset({});
1✔
277
                    memlet->set_base_type(scalar_type);
1✔
278
                }
1✔
279
                for (auto& oedge : access->get_parent().out_edges(*access)) {
2✔
280
                    auto memlet = &oedge;
1✔
281
                    memlet->set_subset({});
1✔
282
                    memlet->set_base_type(scalar_type);
1✔
283
                }
1✔
284
            }
2✔
285
        }
2✔
286

287
        // Replace container name in the loop body
288
        loop_.replace(symbolic::symbol(this->container_), symbolic::symbol(local_name_));
1✔
289
    }
1✔
290
    // ========================================================================
291
    // ARRAY PATH: tile_info_.dimensions is non-empty
292
    // ========================================================================
293
    else {
20✔
294
        // Collect varying dimensions (extent > 1) and compute buffer layout.
295
        // Extent-1 dimensions are degenerate (no loop is needed) and must be
296
        // skipped when sizing the buffer, when creating copy indvars, and when
297
        // linearizing into the local buffer.  The bookkeeping must match what
298
        // `build_original_subset` expects (it indexes copy_indices by varying
299
        // dimension only).
300
        std::vector<size_t> varying_dims;
20✔
301
        std::vector<symbolic::Expression> dim_sizes;
20✔
302
        for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
55✔
303
            auto& dim_size = tile_info_.dimensions.at(d);
35✔
304
            if (!symbolic::eq(dim_size, symbolic::integer(1))) {
35✔
305
                varying_dims.push_back(d);
20✔
306
                dim_sizes.push_back(dim_size);
20✔
307
            }
20✔
308
        }
35✔
309

310
        // Compute total buffer size
311
        symbolic::Expression total_size = symbolic::integer(1);
20✔
312
        for (auto& ds : dim_sizes) {
20✔
313
            total_size = symbolic::mul(total_size, ds);
20✔
314
        }
20✔
315

316
        // Create the local buffer with specified storage type
317
        types::Array buffer_type(storage_type_, 0, {}, scalar_type, total_size);
20✔
318
        builder.add_container(local_name_, buffer_type);
20✔
319

320
        // Helper: build linearized local index from per-dimension expressions
321
        auto linearize_exprs = [&](const std::vector<symbolic::Expression>& indices) -> symbolic::Expression {
65✔
322
            symbolic::Expression linear_idx = symbolic::integer(0);
65✔
323
            symbolic::Expression stride = symbolic::integer(1);
65✔
324
            for (int i = indices.size() - 1; i >= 0; i--) {
128✔
325
                linear_idx = symbolic::add(linear_idx, symbolic::mul(indices[i], stride));
63✔
326
                stride = symbolic::mul(stride, dim_sizes[i]);
63✔
327
            }
63✔
328
            return linear_idx;
65✔
329
        };
65✔
330

331
        // Helper: build linearized local index from per-dimension indvars (symbols)
332
        auto linearize = [&](const std::vector<symbolic::Symbol>& indvars) -> symbolic::Expression {
29✔
333
            std::vector<symbolic::Expression> exprs(indvars.begin(), indvars.end());
29✔
334
            return linearize_exprs(exprs);
29✔
335
        };
29✔
336

337
        // Helper: build source subset (base[d] + copy_indvar[d]) for original container
338
        auto build_original_subset = [&](const std::vector<symbolic::Expression>& copy_indices) -> data_flow::Subset {
34✔
339
            std::vector<symbolic::Expression> full_indices;
34✔
340
            size_t var_idx = 0;
34✔
341
            for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
95✔
342
                if (!symbolic::eq(tile_info_.dimensions.at(d), symbolic::integer(1))) {
61✔
343
                    full_indices.push_back(symbolic::add(tile_info_.bases.at(d), copy_indices.at(var_idx++)));
34✔
344
                } else {
34✔
345
                    full_indices.push_back(tile_info_.bases.at(d));
27✔
346
                }
27✔
347
            }
61✔
348

349
            symbolic::Expression linear = tile_info_.offset;
34✔
350
            for (size_t d = 0; d < full_indices.size(); d++) {
95✔
351
                linear = symbolic::add(linear, symbolic::mul(tile_info_.strides.at(d), full_indices.at(d)));
61✔
352
            }
61✔
353
            return {linear};
34✔
354
        };
34✔
355

356
        if (storage_type_.is_nv_shared()) {
20✔
357
            // ============================================================
358
            // GPU COOPERATIVE PATH
359
            // ============================================================
360
            auto ancestors = ControlFlowNode::parent_chain(loop_);
4✔
361

362
            // Collect cooperative GPU dimensions
363
            struct CoopDim {
4✔
364
                symbolic::Symbol indvar;
4✔
365
                symbolic::Integer block_size;
4✔
366
                gpu::GPUDimension dimension;
4✔
367
            };
4✔
368
            std::vector<CoopDim> coop_dims;
4✔
369

370
            for (auto* node : ancestors) {
20✔
371
                if (auto* ancestor_map = dynamic_cast<structured_control_flow::Map*>(node)) {
20✔
372
                    if (!gpu::is_gpu_schedule(ancestor_map->schedule_type())) {
8✔
373
                        continue;
×
374
                    }
×
375
                    bool appears_in_bases = false;
8✔
376
                    for (auto& base : tile_info_.bases) {
11✔
377
                        if (symbolic::uses(base, ancestor_map->indvar())) {
11✔
378
                            appears_in_bases = true;
3✔
379
                            break;
3✔
380
                        }
3✔
381
                    }
11✔
382
                    if (!appears_in_bases) {
8✔
383
                        coop_dims.push_back(
5✔
384
                            {ancestor_map->indvar(),
5✔
385
                             gpu::gpu_block_size(ancestor_map->schedule_type()),
5✔
386
                             gpu::gpu_dimension(ancestor_map->schedule_type())}
5✔
387
                        );
5✔
388
                    }
5✔
389
                }
8✔
390
            }
20✔
391

392
            // Compute total cooperative thread count
393
            symbolic::Expression total_coop_threads = symbolic::integer(1);
4✔
394
            for (auto& cd : coop_dims) {
5✔
395
                total_coop_threads = symbolic::mul(total_coop_threads, cd.block_size);
5✔
396
            }
5✔
397

398
            // Flatten cooperative thread index
399
            symbolic::Expression coop_flat = symbolic::integer(0);
4✔
400
            symbolic::Expression coop_stride = symbolic::integer(1);
4✔
401
            for (int i = coop_dims.size() - 1; i >= 0; i--) {
9✔
402
                coop_flat = symbolic::add(coop_flat, symbolic::mul(coop_dims[i].indvar, coop_stride));
5✔
403
                coop_stride = symbolic::mul(coop_stride, coop_dims[i].block_size);
5✔
404
            }
5✔
405

406
            // INIT: barrier → cooperative copy-in → barrier (if has_read)
407
            if (tile_info_.has_read) {
4✔
408
                // Barrier before init
409
                auto& barrier_block1 = builder.add_block_before(*parent, loop_, {}, loop_.debug_info());
1✔
410
                builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block1, {});
1✔
411

412
                // Cooperative copy-in loop
413
                auto idx_name = builder.find_new_name("__daisy_ols_coop_init_" + this->container_);
1✔
414
                types::Scalar idx_type(types::PrimitiveType::UInt64);
1✔
415
                builder.add_container(idx_name, idx_type);
1✔
416
                auto idx_var = symbolic::symbol(idx_name);
1✔
417

418
                auto& init_loop = builder.add_map_before(
1✔
419
                    *parent,
1✔
420
                    loop_,
1✔
421
                    idx_var,
1✔
422
                    symbolic::Lt(idx_var, total_size),
1✔
423
                    coop_flat,
1✔
424
                    symbolic::add(idx_var, total_coop_threads),
1✔
425
                    structured_control_flow::ScheduleType_Sequential::create(),
1✔
426
                    {},
1✔
427
                    loop_.debug_info()
1✔
428
                );
1✔
429

430
                auto& init_block = builder.add_block(init_loop.root());
1✔
431
                auto& init_src = builder.add_access(init_block, this->container_);
1✔
432
                auto& init_dst = builder.add_access(init_block, local_name_);
1✔
433
                auto& init_tasklet = builder.add_tasklet(init_block, data_flow::TaskletCode::assign, "_out", {"_in"});
1✔
434

435
                // Decompose idx_var into per-dim indices over varying dims only
436
                std::vector<symbolic::Expression> init_indices;
1✔
437
                symbolic::Expression remainder = idx_var;
1✔
438
                for (size_t i = 0; i < dim_sizes.size(); i++) {
2✔
439
                    if (i < dim_sizes.size() - 1) {
1✔
UNCOV
440
                        symbolic::Expression divisor = symbolic::integer(1);
×
NEW
441
                        for (size_t j = i + 1; j < dim_sizes.size(); j++) {
×
NEW
442
                            divisor = symbolic::mul(divisor, dim_sizes[j]);
×
UNCOV
443
                        }
×
UNCOV
444
                        init_indices.push_back(symbolic::div(remainder, divisor));
×
UNCOV
445
                        remainder = symbolic::mod(remainder, divisor);
×
446
                    } else {
1✔
447
                        init_indices.push_back(remainder);
1✔
448
                    }
1✔
449
                }
1✔
450

451
                auto init_src_subset = build_original_subset(init_indices);
1✔
452
                builder
1✔
453
                    .add_computational_memlet(init_block, init_src, init_tasklet, "_in", init_src_subset, pointer_type);
1✔
454
                builder.add_computational_memlet(init_block, init_tasklet, "_out", init_dst, {idx_var}, buffer_type);
1✔
455

456
                // Barrier after init
457
                auto& barrier_block2 = builder.add_block_before(*parent, loop_, {}, loop_.debug_info());
1✔
458
                builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block2, {});
1✔
459
            }
1✔
460

461
            // WRITEBACK: barrier → cooperative copy-out → barrier
462
            {
4✔
463
                // Barrier before writeback
464
                auto& barrier_block3 = builder.add_block_after(*parent, loop_, {}, loop_.debug_info());
4✔
465
                builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block3, {});
4✔
466

467
                // Cooperative writeback loop
468
                auto idx_name = builder.find_new_name("__daisy_ols_coop_wb_" + this->container_);
4✔
469
                types::Scalar idx_type(types::PrimitiveType::UInt64);
4✔
470
                builder.add_container(idx_name, idx_type);
4✔
471
                auto idx_var = symbolic::symbol(idx_name);
4✔
472

473
                auto& wb_loop = builder.add_map_after(
4✔
474
                    *parent,
4✔
475
                    loop_,
4✔
476
                    idx_var,
4✔
477
                    symbolic::Lt(idx_var, total_size),
4✔
478
                    coop_flat,
4✔
479
                    symbolic::add(idx_var, total_coop_threads),
4✔
480
                    structured_control_flow::ScheduleType_Sequential::create(),
4✔
481
                    {},
4✔
482
                    loop_.debug_info()
4✔
483
                );
4✔
484

485
                auto& wb_block = builder.add_block(wb_loop.root());
4✔
486
                auto& wb_src = builder.add_access(wb_block, local_name_);
4✔
487
                auto& wb_dst = builder.add_access(wb_block, this->container_);
4✔
488
                auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
4✔
489

490
                // Decompose idx_var into per-dim indices over varying dims only
491
                std::vector<symbolic::Expression> wb_indices;
4✔
492
                symbolic::Expression remainder = idx_var;
4✔
493
                for (size_t i = 0; i < dim_sizes.size(); i++) {
8✔
494
                    if (i < dim_sizes.size() - 1) {
4✔
UNCOV
495
                        symbolic::Expression divisor = symbolic::integer(1);
×
NEW
496
                        for (size_t j = i + 1; j < dim_sizes.size(); j++) {
×
NEW
497
                            divisor = symbolic::mul(divisor, dim_sizes[j]);
×
UNCOV
498
                        }
×
UNCOV
499
                        wb_indices.push_back(symbolic::div(remainder, divisor));
×
UNCOV
500
                        remainder = symbolic::mod(remainder, divisor);
×
501
                    } else {
4✔
502
                        wb_indices.push_back(remainder);
4✔
503
                    }
4✔
504
                }
4✔
505

506
                auto wb_dst_subset = build_original_subset(wb_indices);
4✔
507
                builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", {idx_var}, buffer_type);
4✔
508
                builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, wb_dst_subset, pointer_type);
4✔
509

510
                // Barrier after writeback
511
                auto& barrier_block4 = builder.add_block_after(*parent, loop_, {}, loop_.debug_info());
4✔
512
                builder.add_library_node<data_flow::BarrierLocalNode>(barrier_block4, {});
4✔
513
            }
4✔
514
        } else {
16✔
515
            // ============================================================
516
            // CPU SEQUENTIAL PATH
517
            // ============================================================
518
            if (tile_info_.has_read) {
16✔
519
                std::vector<symbolic::Symbol> init_indvars;
13✔
520
                structured_control_flow::Sequence* init_scope = parent;
13✔
521
                bool first_init_loop = true;
13✔
522

523
                for (size_t i = 0; i < varying_dims.size(); i++) {
26✔
524
                    size_t d = varying_dims[i];
13✔
525
                    auto indvar_name =
13✔
526
                        builder.find_new_name("__daisy_ols_init_" + this->container_ + "_d" + std::to_string(d));
13✔
527
                    types::Scalar indvar_type(types::PrimitiveType::UInt64);
13✔
528
                    builder.add_container(indvar_name, indvar_type);
13✔
529
                    auto indvar = symbolic::symbol(indvar_name);
13✔
530
                    init_indvars.push_back(indvar);
13✔
531

532
                    auto init = symbolic::integer(0);
13✔
533
                    auto condition = symbolic::Lt(indvar, dim_sizes[i]);
13✔
534
                    auto update = symbolic::add(indvar, symbolic::integer(1));
13✔
535

536
                    if (first_init_loop) {
13✔
537
                        auto& init_loop = builder.add_map_before(
8✔
538
                            *init_scope,
8✔
539
                            loop_,
8✔
540
                            indvar,
8✔
541
                            condition,
8✔
542
                            init,
8✔
543
                            update,
8✔
544
                            structured_control_flow::ScheduleType_Sequential::create(),
8✔
545
                            {},
8✔
546
                            loop_.debug_info()
8✔
547
                        );
8✔
548
                        init_scope = &init_loop.root();
8✔
549
                        first_init_loop = false;
8✔
550
                    } else {
8✔
551
                        auto& init_loop = builder.add_map(
5✔
552
                            *init_scope,
5✔
553
                            indvar,
5✔
554
                            condition,
5✔
555
                            init,
5✔
556
                            update,
5✔
557
                            structured_control_flow::ScheduleType_Sequential::create(),
5✔
558
                            {},
5✔
559
                            loop_.debug_info()
5✔
560
                        );
5✔
561
                        init_scope = &init_loop.root();
5✔
562
                    }
5✔
563
                }
13✔
564

565
                // Create init copy block
566
                auto& init_block = builder.add_block(*init_scope);
13✔
567
                auto& init_src = builder.add_access(init_block, this->container_);
13✔
568
                auto& init_dst = builder.add_access(init_block, local_name_);
13✔
569
                auto& init_tasklet = builder.add_tasklet(init_block, data_flow::TaskletCode::assign, "_out", {"_in"});
13✔
570

571
                std::vector<symbolic::Expression> init_exprs(init_indvars.begin(), init_indvars.end());
13✔
572
                auto init_src_subset = build_original_subset(init_exprs);
13✔
573
                data_flow::Subset init_dst_subset = {linearize(init_indvars)};
13✔
574

575
                builder
13✔
576
                    .add_computational_memlet(init_block, init_src, init_tasklet, "_in", init_src_subset, pointer_type);
13✔
577
                builder
13✔
578
                    .add_computational_memlet(init_block, init_tasklet, "_out", init_dst, init_dst_subset, buffer_type);
13✔
579
            }
13✔
580

581
            // Writeback Maps
582
            {
16✔
583
                std::vector<symbolic::Symbol> wb_indvars;
16✔
584
                structured_control_flow::Sequence* wb_scope = parent;
16✔
585
                bool first_wb_loop = true;
16✔
586

587
                for (size_t i = 0; i < varying_dims.size(); i++) {
32✔
588
                    size_t d = varying_dims[i];
16✔
589
                    auto indvar_name =
16✔
590
                        builder.find_new_name("__daisy_ols_wb_" + this->container_ + "_d" + std::to_string(d));
16✔
591
                    types::Scalar indvar_type(types::PrimitiveType::UInt64);
16✔
592
                    builder.add_container(indvar_name, indvar_type);
16✔
593
                    auto indvar = symbolic::symbol(indvar_name);
16✔
594
                    wb_indvars.push_back(indvar);
16✔
595

596
                    auto init = symbolic::integer(0);
16✔
597
                    auto condition = symbolic::Lt(indvar, dim_sizes[i]);
16✔
598
                    auto update = symbolic::add(indvar, symbolic::integer(1));
16✔
599

600
                    if (first_wb_loop) {
16✔
601
                        auto& wb_loop = builder.add_map_after(
11✔
602
                            *wb_scope,
11✔
603
                            loop_,
11✔
604
                            indvar,
11✔
605
                            condition,
11✔
606
                            init,
11✔
607
                            update,
11✔
608
                            structured_control_flow::ScheduleType_Sequential::create(),
11✔
609
                            {},
11✔
610
                            loop_.debug_info()
11✔
611
                        );
11✔
612
                        wb_scope = &wb_loop.root();
11✔
613
                        first_wb_loop = false;
11✔
614
                    } else {
11✔
615
                        auto& wb_loop = builder.add_map(
5✔
616
                            *wb_scope,
5✔
617
                            indvar,
5✔
618
                            condition,
5✔
619
                            init,
5✔
620
                            update,
5✔
621
                            structured_control_flow::ScheduleType_Sequential::create(),
5✔
622
                            {},
5✔
623
                            loop_.debug_info()
5✔
624
                        );
5✔
625
                        wb_scope = &wb_loop.root();
5✔
626
                    }
5✔
627
                }
16✔
628

629
                // Create writeback copy block
630
                auto& wb_block = builder.add_block(*wb_scope);
16✔
631
                auto& wb_src = builder.add_access(wb_block, local_name_);
16✔
632
                auto& wb_dst = builder.add_access(wb_block, this->container_);
16✔
633
                auto& wb_tasklet = builder.add_tasklet(wb_block, data_flow::TaskletCode::assign, "_out", {"_in"});
16✔
634

635
                std::vector<symbolic::Expression> wb_exprs(wb_indvars.begin(), wb_indvars.end());
16✔
636
                data_flow::Subset wb_src_subset = {linearize(wb_indvars)};
16✔
637
                auto wb_dst_subset = build_original_subset(wb_exprs);
16✔
638

639
                builder.add_computational_memlet(wb_block, wb_src, wb_tasklet, "_in", wb_src_subset, buffer_type);
16✔
640
                builder.add_computational_memlet(wb_block, wb_tasklet, "_out", wb_dst, wb_dst_subset, pointer_type);
16✔
641
            }
16✔
642
        }
16✔
643

644
        // ==================================================================
645
        // Update accesses in the main loop to use the local buffer
646
        // ==================================================================
647
        auto& mla = analysis_manager.get<analysis::MemoryLayoutAnalysis>();
20✔
648

649
        // Recursive helper to traverse all blocks in the loop body
650
        std::function<void(structured_control_flow::ControlFlowNode&)> rewrite_accesses;
20✔
651
        rewrite_accesses = [&](structured_control_flow::ControlFlowNode& node) {
63✔
652
            if (auto* block = dynamic_cast<structured_control_flow::Block*>(&node)) {
63✔
653
                auto& dfg = block->dataflow();
23✔
654

655
                // Collect access nodes to process (avoid iterator invalidation when splitting)
656
                std::vector<data_flow::AccessNode*> access_nodes;
23✔
657
                for (auto* access_node : dfg.data_nodes()) {
65✔
658
                    if (access_node->data() == this->container_) {
65✔
659
                        access_nodes.push_back(access_node);
36✔
660
                    }
36✔
661
                }
65✔
662

663
                for (auto* access : access_nodes) {
36✔
664
                    // Classify memlets: group vs non-group
665
                    struct MemletRewrite {
36✔
666
                        data_flow::Memlet* memlet;
36✔
667
                        data_flow::Subset local_subset;
36✔
668
                        bool is_outgoing;
36✔
669
                    };
36✔
670
                    std::vector<MemletRewrite> group_rewrites;
36✔
671
                    bool all_in_group = true;
36✔
672

673
                    // Outgoing memlets (reads from this access node)
674
                    for (auto& memlet : dfg.out_edges(*access)) {
36✔
675
                        if (group_memlets_.count(&memlet) == 0) {
15✔
NEW
676
                            all_in_group = false;
×
677
                            continue;
×
678
                        }
×
679
                        auto* acc = mla.access(memlet);
15✔
680
                        if (acc && acc->subset.size() == tile_info_.dimensions.size()) {
15✔
681
                            std::vector<symbolic::Expression> local_indices;
15✔
682
                            for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
43✔
683
                                if (!symbolic::eq(tile_info_.dimensions.at(d), symbolic::integer(1))) {
28✔
684
                                    local_indices.push_back(symbolic::sub(acc->subset.at(d), tile_info_.bases.at(d)));
14✔
685
                                }
14✔
686
                            }
28✔
687
                            symbolic::Expression linear_idx = linearize_exprs(local_indices);
15✔
688
                            group_rewrites.push_back({&memlet, {linear_idx}, true});
15✔
689
                        } else {
15✔
690
                            // Memlet is claimed by the group but we cannot rewrite it (no
691
                            // delinearized access info). Leaving it as the original container
692
                            // would create a half-renamed access node. Bail out of renaming
693
                            // to keep the SDFG consistent.
NEW
694
                            all_in_group = false;
×
UNCOV
695
                        }
×
696
                    }
15✔
697
                    // Incoming memlets (writes to this access node)
698
                    for (auto& memlet : dfg.in_edges(*access)) {
36✔
699
                        if (group_memlets_.count(&memlet) == 0) {
21✔
NEW
700
                            all_in_group = false;
×
701
                            continue;
×
702
                        }
×
703
                        auto* acc = mla.access(memlet);
21✔
704
                        if (acc && acc->subset.size() == tile_info_.dimensions.size()) {
21✔
705
                            std::vector<symbolic::Expression> local_indices;
21✔
706
                            for (size_t d = 0; d < tile_info_.dimensions.size(); d++) {
58✔
707
                                if (!symbolic::eq(tile_info_.dimensions.at(d), symbolic::integer(1))) {
37✔
708
                                    local_indices.push_back(symbolic::sub(acc->subset.at(d), tile_info_.bases.at(d)));
20✔
709
                                }
20✔
710
                            }
37✔
711
                            symbolic::Expression linear_idx = linearize_exprs(local_indices);
21✔
712
                            group_rewrites.push_back({&memlet, {linear_idx}, false});
21✔
713
                        } else {
21✔
NEW
714
                            all_in_group = false;
×
UNCOV
715
                        }
×
716
                    }
21✔
717

718
                    if (group_rewrites.empty()) continue;
36✔
719

720
                    if (all_in_group) {
36✔
721
                        // Simple case: all memlets in group → rewrite in-place and rename
722
                        for (auto& rw : group_rewrites) {
36✔
723
                            rw.memlet->set_subset(rw.local_subset);
36✔
724
                            rw.memlet->set_base_type(buffer_type);
36✔
725
                        }
36✔
726
                        access->data(local_name_);
36✔
727
                    } else {
36✔
728
                        // Mixed case: split — create new local access node, redirect group memlets
NEW
729
                        auto& local_access = builder.add_access(*block, local_name_);
×
NEW
730
                        for (auto& rw : group_rewrites) {
×
NEW
731
                            if (rw.is_outgoing) {
×
732
                                // outgoing: access→tasklet  →  local_access→tasklet
NEW
733
                                auto& dst_node = rw.memlet->dst();
×
NEW
734
                                auto dst_conn = rw.memlet->dst_conn();
×
NEW
735
                                builder.remove_memlet(*block, *rw.memlet);
×
NEW
736
                                builder.add_memlet(
×
NEW
737
                                    *block, local_access, "void", dst_node, dst_conn, rw.local_subset, buffer_type, {}
×
NEW
738
                                );
×
NEW
739
                            } else {
×
740
                                // incoming: tasklet→access  →  tasklet→local_access
NEW
741
                                auto& src_node = rw.memlet->src();
×
NEW
742
                                auto src_conn = rw.memlet->src_conn();
×
NEW
743
                                builder.remove_memlet(*block, *rw.memlet);
×
NEW
744
                                builder.add_memlet(
×
NEW
745
                                    *block, src_node, src_conn, local_access, "void", rw.local_subset, buffer_type, {}
×
NEW
746
                                );
×
NEW
747
                            }
×
NEW
748
                        }
×
UNCOV
749
                    }
×
750
                }
36✔
751
            } else if (auto* seq = dynamic_cast<structured_control_flow::Sequence*>(&node)) {
40✔
752
                for (size_t i = 0; i < seq->size(); i++) {
63✔
753
                    rewrite_accesses(seq->at(i).first);
33✔
754
                }
33✔
755
            } else if (auto* loop = dynamic_cast<structured_control_flow::StructuredLoop*>(&node)) {
30✔
756
                rewrite_accesses(loop->root());
10✔
757
            } else if (auto* if_else = dynamic_cast<structured_control_flow::IfElse*>(&node)) {
10✔
758
                for (size_t i = 0; i < if_else->size(); i++) {
×
759
                    rewrite_accesses(if_else->at(i).first);
×
760
                }
×
761
            }
×
762
        };
63✔
763
        rewrite_accesses(loop_.root());
20✔
764
    }
20✔
765

766
    // Cleanup
767
    analysis_manager.invalidate_all();
21✔
768

769
    passes::SequenceFusion sf_pass;
21✔
770
    passes::DeadCFGElimination dce_pass;
21✔
771
    bool applies = false;
21✔
772
    do {
21✔
773
        applies = false;
21✔
774
        applies |= dce_pass.run(builder, analysis_manager);
21✔
775
        applies |= sf_pass.run(builder, analysis_manager);
21✔
776
    } while (applies);
21✔
777
};
21✔
778

779
/// Serializes this transformation into `j`.
///
/// Writes two keys:
///  - `subgraph`: `"0"` describes the target loop (its element id and `"for"`/`"map"` kind);
///    `"1"` describes the access node (its element id and the fixed type `"access_node"`).
///  - `transformation_type`: the value returned by `name()`.
///
/// @param j JSON object that receives the description.
/// @throws std::runtime_error if the loop is neither a `For` nor a `Map`; `j` is not modified in that case.
void OutLocalStorage::to_json(nlohmann::json& j) const {
    // Determine the loop kind first so an unsupported loop is rejected before `j` is touched.
    std::string kind;
    if (dynamic_cast<structured_control_flow::For*>(&loop_) != nullptr) {
        kind = "for";
    } else if (dynamic_cast<structured_control_flow::Map*>(&loop_) != nullptr) {
        kind = "map";
    }
    if (kind.empty()) {
        throw std::runtime_error("Unsupported loop type for serialization of loop: " + loop_.indvar()->get_name());
    }

    nlohmann::json loop_entry = {{"element_id", loop_.element_id()}, {"type", kind}};
    nlohmann::json node_entry = {{"element_id", access_node_.element_id()}, {"type", "access_node"}};

    j["subgraph"] = {{"0", loop_entry}, {"1", node_entry}};
    j["transformation_type"] = name();
}
3✔
794

795
/// Reconstructs an OutLocalStorage transformation from a JSON description.
///
/// Reads `subgraph["0"]["element_id"]` (the target loop) and `subgraph["1"]["element_id"]`
/// (the access node), matching the layout written by to_json(). The `"type"` fields are not
/// consulted. The storage type is not read from `desc`. The transformation is built with the
/// two-argument constructor call; presumably the constructor's default storage type applies
/// (TODO confirm against the header).
///
/// @param builder Builder whose SDFG contains the referenced elements.
/// @param desc    JSON description of the transformation.
/// @return        The reconstructed transformation.
/// @throws InvalidTransformationDescriptionException if the loop element is missing or is not a
///         structured loop, or if the access node is missing or is not an access node.
/// @throws nlohmann::json exceptions (from `.at()` / `.get()`) if required keys are absent or mistyped.
OutLocalStorage OutLocalStorage::from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& desc) {
    // Use .at() throughout: operator[] on a const json with a missing key is undefined behaviour.
    const auto& subgraph = desc.at("subgraph");

    const auto loop_id = subgraph.at("0").at("element_id").get<size_t>();
    auto element = builder.find_element_by_id(loop_id);
    if (!element) {
        throw InvalidTransformationDescriptionException("Element with ID " + std::to_string(loop_id) + " not found.");
    }
    // The element may exist but be of a different kind; reject it rather than dereferencing null below.
    auto loop = dynamic_cast<structured_control_flow::StructuredLoop*>(element);
    if (!loop) {
        throw InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(loop_id) + " is not a structured loop."
        );
    }

    const auto access_node_id = subgraph.at("1").at("element_id").get<size_t>();
    auto access_node = dynamic_cast<data_flow::AccessNode*>(builder.find_element_by_id(access_node_id));
    if (!access_node) {
        throw InvalidTransformationDescriptionException(
            "Access node with ID " + std::to_string(access_node_id) + " not found."
        );
    }

    return OutLocalStorage(*loop, *access_node);
}
1✔
814

815
} // namespace transformations
816
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc