• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22020750556

14 Feb 2026 04:38PM UTC coverage: 64.828% (-1.5%) from 66.315%
22020750556

Pull #524

github

web-flow
Merge 2784aa264 into 9d01cacd5
Pull Request #524: Native Tensor Support - Step 2: Use tensor types on memlets of tensor nodes

245 of 570 new or added lines in 24 files covered. (42.98%)

458 existing lines in 18 files now uncovered.

23080 of 35602 relevant lines covered (64.83%)

371.57 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

55.65
/sdfg/src/data_flow/memlet.cpp
1
#include <sdfg/data_flow/memlet.h>
2

3
#include "sdfg/data_flow/library_node.h"
4
#include "sdfg/data_flow/tasklet.h"
5
#include "sdfg/function.h"
6
#include "sdfg/symbolic/symbolic.h"
7
#include "sdfg/types/type.h"
8
#include "sdfg/types/utils.h"
9

10
namespace sdfg {
11
namespace data_flow {
12

13
/// Constructs a memlet (data-flow edge) between two data-flow nodes.
///
/// @param element_id  Unique id forwarded to the Element base.
/// @param debug_info  Debug information forwarded to the Element base.
/// @param edge        Underlying graph edge; stored by value.
/// @param parent      Owning data-flow graph; stored as a non-owning pointer.
/// @param src         Source node; stored by reference.
/// @param src_conn    Connector name at the source node.
/// @param dst         Destination node; stored by reference.
/// @param dst_conn    Connector name at the destination node.
/// @param subset      Index expressions applied to the base type.
/// @param base_type   Type the subset indexes into; a private clone is kept.
Memlet::Memlet(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Edge& edge,
    DataFlowGraph& parent,
    DataFlowNode& src,
    const std::string& src_conn,
    DataFlowNode& dst,
    const std::string& dst_conn,
    const Subset& subset,
    const types::IType& base_type
)
    : Element(element_id, debug_info),
      edge_(edge),
      parent_(&parent),
      src_(src),
      dst_(dst),
      src_conn_(src_conn),
      dst_conn_(dst_conn),
      subset_(subset),
      base_type_(base_type.clone()) {}
29

30
void Memlet::validate(const Function& function) const {
8,465✔
31
    // Validate subset
32
    for (const auto& dim : this->subset_) {
8,465✔
33
        // Null ptr check
34
        if (dim.is_null()) {
7,655✔
35
            throw InvalidSDFGException("Memlet: Subset dimensions cannot be null");
×
36
        }
×
37
    }
7,655✔
38

39
    // Validate connections
40
    switch (this->type()) {
8,465✔
41
        case MemletType::Computational: {
8,315✔
42
            // Criterion: Must connect a code node and an access node with void connector at access node
43
            const AccessNode* data_node = nullptr;
8,315✔
44
            const CodeNode* code_node = nullptr;
8,315✔
45
            if (this->src_conn_ == "void") {
8,315✔
46
                data_node = dynamic_cast<const AccessNode*>(&this->src_);
5,166✔
47
                code_node = dynamic_cast<const CodeNode*>(&this->dst_);
5,166✔
48
                if (!data_node || !code_node) {
5,166✔
49
                    throw InvalidSDFGException("Memlet: Computation memlets must connect a code node and an access node"
×
50
                    );
×
51
                }
×
52

53
                // Criterion: Non-void connector must be an input of the code node
54
                if (std::find(code_node->inputs().begin(), code_node->inputs().end(), this->dst_conn_) ==
5,166✔
55
                    code_node->inputs().end()) {
5,166✔
56
                    throw InvalidSDFGException("Memlet: Computation memlets must have an input in the code node");
×
57
                }
×
58
            } else if (this->dst_conn_ == "void") {
5,166✔
59
                data_node = dynamic_cast<const AccessNode*>(&this->dst_);
3,149✔
60
                code_node = dynamic_cast<const CodeNode*>(&this->src_);
3,149✔
61
                if (!data_node || !code_node) {
3,149✔
62
                    throw InvalidSDFGException("Memlet: Computation memlets must connect a code node and an access node"
×
63
                    );
×
64
                }
×
65

66
                // Criterion: Non-void connector must be an output of the code node
67
                if (std::find(code_node->outputs().begin(), code_node->outputs().end(), this->src_conn_) ==
3,149✔
68
                    code_node->outputs().end()) {
3,149✔
69
                    throw InvalidSDFGException("Memlet: Computation memlets must have an output in the code node");
×
70
                }
×
71
            } else {
3,149✔
72
                throw InvalidSDFGException(
×
73
                    "Memlet: Computation memlets must have void connector at source or destination"
×
74
                );
×
75
            }
×
76

77
            // If tensor, check that the type is consistenly defined
78
            if (this->base_type_->type_id() == types::TypeID::Tensor) {
8,315✔
79
                auto& tensor_type = dynamic_cast<const types::Tensor&>(*this->base_type_);
907✔
80
                if (tensor_type.is_scalar()) {
907✔
81
                    if (auto const_node = dynamic_cast<const data_flow::ConstantNode*>(data_node)) {
78✔
82
                        if (const_node->type().type_id() != types::TypeID::Scalar) {
30✔
NEW
83
                            throw InvalidSDFGException(
×
NEW
84
                                "Memlet: Scalar tensors must reference scalar buffers. Base type: " +
×
NEW
85
                                this->base_type_->print() + " Buffer type: " + const_node->type().print()
×
NEW
86
                            );
×
NEW
87
                        }
×
88
                    } else {
48✔
89
                        auto& buffer_type = function.type(data_node->data());
48✔
90
                        if (buffer_type.type_id() != types::TypeID::Scalar) {
48✔
NEW
91
                            throw InvalidSDFGException(
×
NEW
92
                                "Memlet: Scalar tensors must reference scalar buffers. Base type: " +
×
NEW
93
                                this->base_type_->print() + " Buffer type: " + buffer_type.print()
×
NEW
94
                            );
×
NEW
95
                        }
×
96
                    }
48✔
97
                } else {
829✔
98
                    auto& buffer_type = function.type(data_node->data());
829✔
99
                    if (buffer_type.type_id() != types::TypeID::Pointer) {
829✔
NEW
100
                        throw InvalidSDFGException(
×
NEW
101
                            "Memlet: Non-scalar tensors must reference pointer buffers. Base type: " +
×
NEW
102
                            this->base_type_->print() + " Buffer type: " + buffer_type.print()
×
NEW
103
                        );
×
NEW
104
                    }
×
105
                    if (this->subset_.size() > tensor_type.shape().size()) {
829✔
NEW
106
                        throw InvalidSDFGException(
×
NEW
107
                            "Memlet: Subset dimensions must match base type dimensions. Base type: " +
×
NEW
108
                            this->base_type_->print() + " Subset Dim: " + std::to_string(this->subset_.size())
×
NEW
109
                        );
×
NEW
110
                    }
×
111
                    if (tensor_type.shape().size() != tensor_type.strides().size()) {
829✔
NEW
112
                        throw InvalidSDFGException(
×
NEW
113
                            "Memlet: Tensor types must have the same number of shape and stride dimensions. Base "
×
NEW
114
                            "type: " +
×
NEW
115
                            this->base_type_->print()
×
NEW
116
                        );
×
NEW
117
                    }
×
118
                }
829✔
119
            }
907✔
120
            break;
8,315✔
121
        }
8,315✔
122
        case MemletType::Reference: {
8,315✔
123
            // Criterion: Destination must be an access node with a pointer type
124
            auto dst_node = dynamic_cast<const AccessNode*>(&this->dst_);
111✔
125
            if (!dst_node) {
111✔
126
                throw InvalidSDFGException("Memlet: Reference memlets must have an access node destination");
×
127
            }
×
128
            auto dst_data = dst_node->data();
111✔
129
            // Criterion: Destination must be non-constant
130
            if (helpers::is_number(dst_data) || symbolic::is_nullptr(symbolic::symbol(dst_data))) {
111✔
131
                throw InvalidSDFGException("Memlet: Reference memlets must have a non-constant destination");
×
132
            }
×
133

134
            // Criterion: Destination must be a pointer
135
            auto& dst_type = function.type(dst_data);
111✔
136
            if (dst_type.type_id() != types::TypeID::Pointer) {
111✔
137
                throw InvalidSDFGException("Memlet: Reference memlets must have a pointer destination");
×
138
            }
×
139

140
            // Criterion: Source must be an access node
141
            if (this->src_conn_ != "void") {
111✔
142
                throw InvalidSDFGException("Memlet: Reference memlets must have a void source");
×
143
            }
×
144
            auto src_node = dynamic_cast<const AccessNode*>(&this->src_);
111✔
145
            if (!src_node) {
111✔
146
                throw InvalidSDFGException("Memlet: Reference memlets must have an access node source");
×
147
            }
×
148

149
            // Case: Constant
150
            if (helpers::is_number(src_node->data()) || symbolic::is_nullptr(symbolic::symbol(src_node->data()))) {
111✔
151
                if (!this->subset_.empty()) {
4✔
152
                    throw InvalidSDFGException("Memlet: Reference memlets for raw addresses must not have a subset");
×
153
                }
×
154
                return;
4✔
155
            }
4✔
156

157
            // Case: Container
158
            // Criterion: Must be contiguous memory reference
159
            // Throws exception if not contiguous
160
            types::infer_type(function, *this->base_type_, this->subset_);
107✔
161
            break;
107✔
162
        }
111✔
163
        case MemletType::Dereference_Src: {
27✔
164
            if (this->src_conn_ != "void") {
27✔
165
                throw InvalidSDFGException("Memlet: Dereference memlets must have a void destination");
×
166
            }
×
167

168
            auto src_node = dynamic_cast<const AccessNode*>(&this->src_);
27✔
169
            if (!src_node) {
27✔
170
                throw InvalidSDFGException("Memlet: Dereference memlets must have an access node source");
×
171
            }
×
172
            auto dst_node = dynamic_cast<const AccessNode*>(&this->dst_);
27✔
173
            if (!dst_node) {
27✔
174
                throw InvalidSDFGException("Memlet: Dereference memlets must have an access node destination");
×
175
            }
×
176

177
            // Criterion: Dereference memlets must have '0' as the only dimension
178
            if (this->subset_.size() != 1) {
27✔
179
                throw InvalidSDFGException("Memlet: Dereference memlets must have '0' as the only dimension");
×
180
            }
×
181
            if (!symbolic::eq(this->subset_[0], symbolic::zero())) {
27✔
182
                throw InvalidSDFGException("Memlet: Dereference memlets must have '0' as the only dimension");
×
183
            }
×
184

185
            // Criterion: Source must be a pointer
186
            if (auto const_node = dynamic_cast<const ConstantNode*>(src_node)) {
27✔
187
                if (const_node->type().type_id() != types::TypeID::Pointer &&
×
188
                    const_node->type().type_id() != types::TypeID::Scalar) {
×
189
                    throw InvalidSDFGException("Memlet: Dereference memlets must have a pointer source");
×
190
                }
×
191
            } else {
27✔
192
                auto src_data = src_node->data();
27✔
193
                auto& src_type = function.type(src_data);
27✔
194
                if (src_type.type_id() != types::TypeID::Pointer) {
27✔
195
                    throw InvalidSDFGException("Memlet: Dereference memlets must have a pointer source");
×
196
                }
×
197
            }
27✔
198

199
            // Criterion: Must be typed pointer
200
            auto base_pointer_type = dynamic_cast<const types::Pointer*>(this->base_type_.get());
27✔
201
            if (!base_pointer_type) {
27✔
202
                throw InvalidSDFGException("Memlet: Dereference memlets must have a typed pointer base type");
×
203
            }
×
204
            if (!base_pointer_type->has_pointee_type()) {
27✔
205
                throw InvalidSDFGException("Memlet: Dereference memlets must have a pointee type");
×
206
            }
×
207

208
            break;
27✔
209
        }
27✔
210
        case MemletType::Dereference_Dst: {
27✔
211
            if (this->dst_conn_ != "void") {
12✔
212
                throw InvalidSDFGException("Memlet: Dereference memlets must have a void source");
×
213
            }
×
214

215
            auto src_node = dynamic_cast<const AccessNode*>(&this->src_);
12✔
216
            if (!src_node) {
12✔
217
                throw InvalidSDFGException("Memlet: Dereference memlets must have an access node source");
×
218
            }
×
219
            auto dst_node = dynamic_cast<const AccessNode*>(&this->dst_);
12✔
220
            if (!dst_node) {
12✔
221
                throw InvalidSDFGException("Memlet: Dereference memlets must have an access node destination");
×
222
            }
×
223

224
            // Criterion: Dereference memlets must have '0' as the only dimension
225
            if (this->subset_.size() != 1) {
12✔
226
                throw InvalidSDFGException("Memlet: Dereference memlets must have '0' as the only dimension");
×
227
            }
×
228
            if (!symbolic::eq(this->subset_[0], symbolic::zero())) {
12✔
229
                throw InvalidSDFGException("Memlet: Dereference memlets must have '0' as the only dimension");
×
230
            }
×
231

232
            // Criterion: src type cannot be a function
233
            const sdfg::types::IType* src_type;
12✔
234
            if (auto const_node = dynamic_cast<const data_flow::ConstantNode*>(src_node)) {
12✔
235
                src_type = &const_node->type();
2✔
236
            } else {
10✔
237
                src_type = &function.type(src_node->data());
10✔
238
            }
10✔
239
            if (src_type->type_id() == types::TypeID::Function) {
12✔
240
                throw InvalidSDFGException("Memlet: Dereference memlets cannot have source of type Function");
×
241
            }
×
242

243
            // Criterion: Destination must be a pointer
244
            if (auto const_node = dynamic_cast<const ConstantNode*>(dst_node)) {
12✔
245
                throw InvalidSDFGException("Memlet: Dereference memlets must have a non-constant destination");
×
246
            }
×
247
            auto dst_data = dst_node->data();
12✔
248
            auto& dst_type = function.type(dst_data);
12✔
249
            if (dst_type.type_id() != types::TypeID::Pointer) {
12✔
250
                throw InvalidSDFGException("Memlet: Dereference memlets must have a pointer destination");
×
251
            }
×
252

253
            // Criterion: Must be typed pointer
254
            auto base_pointer_type = dynamic_cast<const types::Pointer*>(this->base_type_.get());
12✔
255
            if (!base_pointer_type) {
12✔
256
                throw InvalidSDFGException("Memlet: Dereference memlets must have a typed pointer base type");
×
257
            }
×
258
            if (!base_pointer_type->has_pointee_type()) {
12✔
259
                throw InvalidSDFGException("Memlet: Dereference memlets must have a pointee type");
×
260
            }
×
261

262
            break;
12✔
263
        }
12✔
264
        default:
12✔
265
            throw InvalidSDFGException("Memlet: Invalid memlet type");
×
266
    }
8,465✔
267
};
8,465✔
268

269
const graph::Edge Memlet::edge() const { return this->edge_; };
406✔
270

271
const DataFlowGraph& Memlet::get_parent() const { return *this->parent_; };
×
272

273
DataFlowGraph& Memlet::get_parent() { return *this->parent_; };
51✔
274

275
/// Classifies this memlet from its connector names.
///
/// The checks are applied in this order, and the first match wins:
///  - destination connector "ref"   -> Reference
///  - destination connector "deref" -> Dereference_Src
///  - source connector "deref"      -> Dereference_Dst
///
/// Anything else is Computational.
MemletType Memlet::type() const {
    const auto& dst_connector = dst_conn_;
    if (dst_connector == "ref") {
        return Reference;
    }
    if (dst_connector == "deref") {
        return Dereference_Src;
    }
    return (src_conn_ == "deref") ? Dereference_Dst : Computational;
}
286

287
const DataFlowNode& Memlet::src() const { return this->src_; };
1,031✔
288

289
DataFlowNode& Memlet::src() { return this->src_; };
5,564✔
290

291
const DataFlowNode& Memlet::dst() const { return this->dst_; };
741✔
292

293
DataFlowNode& Memlet::dst() { return this->dst_; };
1,011✔
294

295
const std::string& Memlet::src_conn() const { return this->src_conn_; };
990✔
296

297
const std::string& Memlet::dst_conn() const { return this->dst_conn_; };
4,774✔
298

299
const Subset& Memlet::subset() const { return this->subset_; };
11,262✔
300

301
void Memlet::set_subset(const Subset& subset) { this->subset_ = subset; };
55✔
302

303
const types::IType& Memlet::base_type() const { return *this->base_type_; };
3,425✔
304

305
void Memlet::set_base_type(const types::IType& base_type) { this->base_type_ = base_type.clone(); };
44✔
306

307
std::unique_ptr<types::IType> Memlet::result_type(const Function& function) const {
3,502✔
308
    return types::infer_type(function, *this->base_type_, this->subset_);
3,502✔
309
};
3,502✔
310

311
std::unique_ptr<Memlet> Memlet::clone(
312
    size_t element_id, const graph::Edge& edge, DataFlowGraph& parent, DataFlowNode& src, DataFlowNode& dst
313
) const {
×
314
    return std::unique_ptr<Memlet>(new Memlet(
×
315
        element_id,
×
316
        this->debug_info_,
×
317
        edge,
×
318
        parent,
×
319
        src,
×
320
        this->src_conn_,
×
321
        dst,
×
322
        this->dst_conn_,
×
323
        this->subset_,
×
324
        *this->base_type_
×
325
    ));
×
326
};
×
327

328
/// Substitutes @p old_expression with @p new_expression in every subset dimension.
///
/// The substituted dimensions are collected in a separate Subset and assigned at
/// the end. If a substitution throws, the existing subset is left unchanged.
void Memlet::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    Subset substituted;
    for (size_t i = 0; i < subset_.size(); ++i) {
        substituted.push_back(symbolic::subs(subset_[i], old_expression, new_expression));
    }
    subset_ = substituted;
}
335

336
} // namespace data_flow
337
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc