• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 24101568844

07 Apr 2026 07:58PM UTC coverage: 64.834% (+0.04%) from 64.797%
24101568844

Pull #653

github

web-flow
Merge cfe8b8db4 into 0e7f1388b
Pull Request #653: adds memlet simplification for contiguous memory accesses

121 of 157 new or added lines in 4 files covered. (77.07%)

3 existing lines in 1 file now uncovered.

28958 of 44665 relevant lines covered (64.83%)

604.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.97
/sdfg/src/passes/dataflow/memlet_simplification.cpp
1
#include "sdfg/passes/dataflow/memlet_simplification.h"
2

3
#include <algorithm>
4
#include <optional>
5
#include <vector>
6

7
#include "sdfg/data_flow/data_flow_graph.h"
8
#include "sdfg/data_flow/memlet.h"
9
#include "sdfg/structured_control_flow/block.h"
10
#include "sdfg/structured_control_flow/map.h"
11
#include "sdfg/structured_control_flow/sequence.h"
12
#include "sdfg/symbolic/symbolic.h"
13
#include "sdfg/visitor/structured_sdfg_visitor.h"
14

15
namespace sdfg {
16
namespace passes {
17

18
namespace {
19

20
/**
21
 * @brief Represents a single term in a mixed-radix decomposition
22
 *
23
 * A term has the form: stride * ((base / divisor) % modulus)
24
 * Special cases:
25
 *   - Outermost: stride * (base / divisor), modulus = infinity (represented as 0)
26
 *   - Innermost: (base % modulus), stride = 1, divisor = 1
27
 */
28
struct MixedRadixTerm {
29
    symbolic::Expression base; // The base index variable
30
    int64_t stride; // Multiplier (product of dims after this one)
31
    int64_t divisor; // What to divide base by
32
    int64_t modulus; // What to mod by (0 = no modulus, i.e., outermost)
33
};
34

35
/**
36
 * @brief Checks if an expression is the idiv function and extracts its arguments
37
 */
38
std::optional<std::pair<symbolic::Expression, int64_t>> parse_idiv(const symbolic::Expression& expr) {
9✔
39
    if (!SymEngine::is_a<SymEngine::FunctionSymbol>(*expr)) {
9✔
40
        return std::nullopt;
3✔
41
    }
3✔
42
    auto func = SymEngine::rcp_static_cast<const SymEngine::FunctionSymbol>(expr);
6✔
43
    if (func->get_name() != "idiv") {
6✔
NEW
44
        return std::nullopt;
×
NEW
45
    }
×
46
    auto args = func->get_args();
6✔
47
    if (args.size() != 2) {
6✔
NEW
48
        return std::nullopt;
×
NEW
49
    }
×
50
    if (!SymEngine::is_a<SymEngine::Integer>(*args[1])) {
6✔
NEW
51
        return std::nullopt;
×
NEW
52
    }
×
53
    auto divisor = SymEngine::rcp_static_cast<const SymEngine::Integer>(args[1])->as_int();
6✔
54
    return std::make_pair(args[0], divisor);
6✔
55
}
6✔
56

57
/**
58
 * @brief Checks if an expression is the imod function and extracts its arguments
59
 */
60
std::optional<std::pair<symbolic::Expression, int64_t>> parse_imod(const symbolic::Expression& expr) {
9✔
61
    if (!SymEngine::is_a<SymEngine::FunctionSymbol>(*expr)) {
9✔
NEW
62
        return std::nullopt;
×
NEW
63
    }
×
64
    auto func = SymEngine::rcp_static_cast<const SymEngine::FunctionSymbol>(expr);
9✔
65
    if (func->get_name() != "imod") {
9✔
66
        return std::nullopt;
3✔
67
    }
3✔
68
    auto args = func->get_args();
6✔
69
    if (args.size() != 2) {
6✔
NEW
70
        return std::nullopt;
×
NEW
71
    }
×
72
    if (!SymEngine::is_a<SymEngine::Integer>(*args[1])) {
6✔
NEW
73
        return std::nullopt;
×
NEW
74
    }
×
75
    auto modulus = SymEngine::rcp_static_cast<const SymEngine::Integer>(args[1])->as_int();
6✔
76
    return std::make_pair(args[0], modulus);
6✔
77
}
6✔
78

79
/**
80
 * @brief Parses a single term of the mixed-radix expression
81
 *
82
 * Handles:
83
 *   - stride * imod(idiv(base, divisor), modulus)  (middle terms)
84
 *   - stride * idiv(base, divisor)                  (outermost term, no mod)
85
 *   - imod(base, modulus)                           (innermost term, stride=1, divisor=1)
86
 *   - imod(idiv(base, divisor), modulus)            (stride=1 middle term)
87
 */
88
std::optional<MixedRadixTerm> parse_term(const symbolic::Expression& term, const symbolic::Symbol& expected_base) {
9✔
89
    MixedRadixTerm result;
9✔
90
    result.stride = 1;
9✔
91
    result.divisor = 1;
9✔
92
    result.modulus = 0;
9✔
93

94
    symbolic::Expression inner = term;
9✔
95

96
    // Check if term is stride * something
97
    if (SymEngine::is_a<SymEngine::Mul>(*term)) {
9✔
98
        auto mul = SymEngine::rcp_static_cast<const SymEngine::Mul>(term);
6✔
99
        auto args = mul->get_args();
6✔
100

101
        // Look for an integer multiplier (stride)
102
        symbolic::Expression non_int_part = symbolic::one();
6✔
103
        bool found_stride = false;
6✔
104

105
        for (const auto& arg : args) {
12✔
106
            if (SymEngine::is_a<SymEngine::Integer>(*arg) && !found_stride) {
12✔
107
                result.stride = SymEngine::rcp_static_cast<const SymEngine::Integer>(arg)->as_int();
6✔
108
                found_stride = true;
6✔
109
            } else {
6✔
110
                non_int_part = SymEngine::mul(non_int_part, arg);
6✔
111
            }
6✔
112
        }
12✔
113

114
        if (found_stride) {
6✔
115
            inner = non_int_part;
6✔
116
        }
6✔
117
    }
6✔
118

119
    // Now inner should be one of:
120
    //   - imod(idiv(base, divisor), modulus)
121
    //   - idiv(base, divisor)  (outermost)
122
    //   - imod(base, modulus)  (innermost)
123
    //   - just base            (trivial case, stride=divisor=1, no mod)
124

125
    // Try: imod(something, modulus)
126
    if (auto mod_result = parse_imod(inner)) {
9✔
127
        result.modulus = mod_result->second;
6✔
128
        inner = mod_result->first;
6✔
129
    }
6✔
130

131
    // Try: idiv(base, divisor)
132
    if (auto div_result = parse_idiv(inner)) {
9✔
133
        result.divisor = div_result->second;
6✔
134
        inner = div_result->first;
6✔
135
    }
6✔
136

137
    // Now inner should be the base symbol
138
    if (!SymEngine::is_a<SymEngine::Symbol>(*inner)) {
9✔
NEW
139
        return std::nullopt;
×
NEW
140
    }
×
141

142
    auto sym = SymEngine::rcp_static_cast<const SymEngine::Symbol>(inner);
9✔
143
    if (!symbolic::eq(sym, expected_base)) {
9✔
NEW
144
        return std::nullopt;
×
NEW
145
    }
×
146

147
    result.base = sym;
9✔
148
    return result;
9✔
149
}
9✔
150

151
} // anonymous namespace
152

153
/**
 * @brief Constructs the pass, forwarding both arguments to the visitor base class
 */
MemletSimplification::MemletSimplification(
    builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager
)
    : visitor::NonStoppingStructuredSDFGVisitor(builder, analysis_manager) {}
9✔
156

157
std::optional<symbolic::Expression> MemletSimplification::
158
    try_simplify_mixed_radix(const symbolic::Expression& expr, const symbolic::Symbol& expected_base) {
11✔
159
    // Handle trivial case: expression is just the base symbol
160
    if (SymEngine::is_a<SymEngine::Symbol>(*expr)) {
11✔
161
        if (symbolic::eq(expr, expected_base)) {
8✔
162
            return expr;
8✔
163
        }
8✔
NEW
164
        return std::nullopt;
×
165
    }
8✔
166

167
    // Must be an Add expression
168
    if (!SymEngine::is_a<SymEngine::Add>(*expr)) {
3✔
NEW
169
        return std::nullopt;
×
NEW
170
    }
×
171

172
    auto add = SymEngine::rcp_static_cast<const SymEngine::Add>(expr);
3✔
173
    auto args = add->get_args();
3✔
174

175
    if (args.size() < 2) {
3✔
NEW
176
        return std::nullopt;
×
NEW
177
    }
×
178

179
    // Parse all terms
180
    std::vector<MixedRadixTerm> terms;
3✔
181
    for (const auto& arg : args) {
9✔
182
        auto parsed = parse_term(arg, expected_base);
9✔
183
        if (!parsed) {
9✔
NEW
184
            return std::nullopt;
×
NEW
185
        }
×
186
        terms.push_back(*parsed);
9✔
187
    }
9✔
188

189
    // Sort by divisor descending (outermost first, innermost last)
190
    std::sort(terms.begin(), terms.end(), [](const MixedRadixTerm& a, const MixedRadixTerm& b) {
6✔
191
        return a.divisor > b.divisor;
6✔
192
    });
6✔
193

194
    // Verify chain property:
195
    // 1. stride[k] == divisor[k] for all k
196
    // 2. For each k (except innermost): divisor[k] == modulus[k+1] * divisor[k+1]
197
    // 3. Innermost divisor must be 1 (no division needed for last dimension)
198

199
    for (size_t k = 0; k < terms.size(); ++k) {
12✔
200
        // Check stride == divisor
201
        if (terms[k].stride != terms[k].divisor) {
9✔
NEW
202
            return std::nullopt;
×
NEW
203
        }
×
204

205
        if (k < terms.size() - 1) {
9✔
206
            // Outermost term (k=0) can have no modulus, others must have it
207
            if (terms[k].modulus == 0 && k != 0) {
6✔
NEW
208
                return std::nullopt;
×
NEW
209
            }
×
210
            // Verify: divisor[k] == modulus[k+1] * divisor[k+1]
211
            if (terms[k].divisor != terms[k + 1].modulus * terms[k + 1].divisor) {
6✔
NEW
212
                return std::nullopt;
×
NEW
213
            }
×
214
        }
6✔
215
    }
9✔
216

217
    // Innermost term must have divisor == 1
218
    if (terms.back().divisor != 1) {
3✔
NEW
219
        return std::nullopt;
×
NEW
220
    }
×
221

222
    // Innermost term must have modulus (since it's base % modulus)
223
    if (terms.back().modulus == 0) {
3✔
NEW
224
        return std::nullopt;
×
NEW
225
    }
×
226

227
    // All checks passed - the expression equals the base index
228
    return expected_base;
3✔
229
}
3✔
230

231
/**
 * @brief Simplifies single-index memlet subsets inside the blocks directly under a Map
 *
 * Maps that are not in loop normal form are skipped. For every Block that is
 * an immediate child of the map body, each memlet whose subset has exactly one
 * entry is passed to try_simplify_mixed_radix together with the map's
 * induction variable. When that yields a different expression, the subset is
 * replaced by it. Children that are not Blocks are not inspected.
 *
 * @param map Map to process
 * @return true if at least one memlet subset was rewritten, false otherwise
 */
bool MemletSimplification::accept(structured_control_flow::Map& map) {
    if (!map.is_loop_normal_form()) {
        return false;
    }

    const auto induction_var = map.indvar();
    auto& children = map.root();
    bool changed = false;

    for (size_t child = 0; child < children.size(); ++child) {
        auto* block = dynamic_cast<structured_control_flow::Block*>(&children.at(child).first);
        if (block == nullptr) {
            continue;
        }

        auto& graph = block->dataflow();
        for (auto& edge : graph.edges()) {
            // Only flat, single-index subsets are candidates.
            if (edge.subset().size() != 1) {
                continue;
            }

            auto& index = edge.subset()[0];
            auto replacement = try_simplify_mixed_radix(index, induction_var);

            // Nothing recognized, or the index already is the simplified form.
            if (!replacement || symbolic::eq(*replacement, index)) {
                continue;
            }

            edge.set_subset({*replacement});
            changed = true;
        }
    }

    return changed;
}
13✔
269

270
} // namespace passes
271
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc