• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22949872003

11 Mar 2026 11:15AM UTC coverage: 63.681% (-0.9%) from 64.6%
22949872003

push

github

web-flow
Merge pull request #569 from daisytuner/HIPtarget

ROCmTarget

191 of 803 new or added lines in 15 files covered. (23.79%)

3 existing lines in 2 files now uncovered.

24700 of 38787 relevant lines covered (63.68%)

370.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/codegen/language_extensions/rocm_language_extension.cpp
1
#include "sdfg/codegen/language_extensions/rocm_language_extension.h"
2

3
#include "sdfg/codegen/language_extensions/cpp_language_extension.h"
4
#include "sdfg/codegen/utils.h"
5
#include "sdfg/data_flow/library_node.h"
6
#include "sdfg/data_flow/tasklet.h"
7

8
namespace sdfg {
9
namespace codegen {
10

NEW
11
// Maps an SDFG primitive type to the C++/HIP type name used in generated ROCm code.
//
// @param prim_type  Primitive type to spell.
// @returns          The C++ spelling, e.g. "int", "unsigned long long", "__fp16".
// @throws std::runtime_error if `prim_type` is not one of the enumerators below.
//
// NOTE(review): UInt8 is spelled as plain `char`, whose signedness is
// implementation-defined (Int8 is spelled explicitly as `signed char`). Byte
// values above 127 may therefore behave as negative numbers in generated
// arithmetic/comparisons. Confirm this is intentional (e.g. for `char*`
// string interop) before changing it.
std::string ROCMLanguageExtension::primitive_type(const types::PrimitiveType prim_type) {
    const char* spelling = nullptr;

    switch (prim_type) {
        case types::PrimitiveType::Void:
            spelling = "void";
            break;
        case types::PrimitiveType::Bool:
            spelling = "bool";
            break;

        // Signed integers
        case types::PrimitiveType::Int8:
            spelling = "signed char";
            break;
        case types::PrimitiveType::Int16:
            spelling = "short";
            break;
        case types::PrimitiveType::Int32:
            spelling = "int";
            break;
        case types::PrimitiveType::Int64:
            spelling = "long long";
            break;
        case types::PrimitiveType::Int128:
            spelling = "__int128";
            break;

        // Unsigned integers
        case types::PrimitiveType::UInt8:
            spelling = "char";
            break;
        case types::PrimitiveType::UInt16:
            spelling = "unsigned short";
            break;
        case types::PrimitiveType::UInt32:
            spelling = "unsigned int";
            break;
        case types::PrimitiveType::UInt64:
            spelling = "unsigned long long";
            break;
        case types::PrimitiveType::UInt128:
            spelling = "unsigned __int128";
            break;

        // Floating point
        case types::PrimitiveType::Half:
            spelling = "__fp16";
            break;
        case types::PrimitiveType::BFloat:
            spelling = "__bf16";
            break;
        case types::PrimitiveType::Float:
            spelling = "float";
            break;
        case types::PrimitiveType::Double:
            spelling = "double";
            break;
        case types::PrimitiveType::X86_FP80:
            spelling = "long double";
            break;
        case types::PrimitiveType::FP128:
        case types::PrimitiveType::PPC_FP128:
            spelling = "__float128";
            break;
    }

    if (spelling != nullptr) {
        return spelling;
    }
    throw std::runtime_error("Unknown primitive type");
}
×
55

56
std::string ROCMLanguageExtension::
NEW
57
    declaration(const std::string& name, const types::IType& type, bool use_initializer, bool use_alignment) {
×
NEW
58
    std::stringstream val;
×
59

NEW
60
    if (auto scalar_type = dynamic_cast<const types::Scalar*>(&type)) {
×
NEW
61
        if (scalar_type->storage_type().is_nv_shared()) {
×
NEW
62
            val << "__shared__ ";
×
NEW
63
        } else if (scalar_type->storage_type().is_nv_constant()) {
×
NEW
64
            val << "__constant__ ";
×
NEW
65
        }
×
NEW
66
        val << primitive_type(scalar_type->primitive_type());
×
NEW
67
        val << " ";
×
NEW
68
        val << name;
×
NEW
69
    } else if (auto array_type = dynamic_cast<const types::Array*>(&type)) {
×
NEW
70
        if (array_type->storage_type().is_nv_shared()) {
×
NEW
71
            val << "__shared__ ";
×
NEW
72
        }
×
NEW
73
        auto& element_type = array_type->element_type();
×
NEW
74
        val << declaration(name + "[" + this->expression(array_type->num_elements()) + "]", element_type);
×
NEW
75
    } else if (auto pointer_type = dynamic_cast<const types::Pointer*>(&type)) {
×
NEW
76
        if (pointer_type->has_pointee_type()) {
×
NEW
77
            const types::IType& pointee = pointer_type->pointee_type();
×
78

NEW
79
            const bool pointee_is_function_or_array = dynamic_cast<const types::Function*>(&pointee) ||
×
NEW
80
                                                      dynamic_cast<const types::Array*>(&pointee);
×
81

82
            // Parenthesise *only* when it is needed to bind tighter than [] or ()
NEW
83
            std::string decorated = pointee_is_function_or_array ? "(*" + name + ")" : "*" + name;
×
84

NEW
85
            val << declaration(decorated, pointee);
×
NEW
86
        } else {
×
NEW
87
            val << "void*";
×
NEW
88
            val << " " << name;
×
NEW
89
        }
×
NEW
90
    } else if (auto ref_type = dynamic_cast<const Reference*>(&type)) {
×
NEW
91
        val << declaration("&" + name, ref_type->reference_type());
×
NEW
92
    } else if (auto structure_type = dynamic_cast<const types::Structure*>(&type)) {
×
NEW
93
        if (structure_type->storage_type().is_nv_shared()) {
×
NEW
94
            val << "__shared__ ";
×
NEW
95
        } else if (structure_type->storage_type().is_nv_constant()) {
×
NEW
96
            val << "__constant__ ";
×
NEW
97
        }
×
NEW
98
        val << structure_type->name();
×
NEW
99
        val << " ";
×
NEW
100
        val << name;
×
NEW
101
    } else if (auto function_type = dynamic_cast<const types::Function*>(&type)) {
×
NEW
102
        std::stringstream params;
×
NEW
103
        for (size_t i = 0; i < function_type->num_params(); ++i) {
×
NEW
104
            params << declaration("", function_type->param_type(symbolic::integer(i)));
×
NEW
105
            if (i + 1 < function_type->num_params()) params << ", ";
×
NEW
106
        }
×
NEW
107
        if (function_type->is_var_arg()) {
×
NEW
108
            if (function_type->num_params() > 0) {
×
NEW
109
                params << ", ";
×
NEW
110
            }
×
NEW
111
            params << "...";
×
NEW
112
        }
×
113

NEW
114
        const std::string fun_name = name + "(" + params.str() + ")";
×
NEW
115
        val << declaration(fun_name, function_type->return_type());
×
NEW
116
    } else {
×
NEW
117
        throw std::runtime_error("Unknown declaration type");
×
NEW
118
    }
×
119

NEW
120
    if (use_alignment && type.alignment() > 0) {
×
NEW
121
        val << " __attribute__((aligned(" << type.alignment() << ")))";
×
NEW
122
    }
×
123

NEW
124
    if (use_initializer && !type.initializer().empty()) {
×
NEW
125
        val << " = " << type.initializer();
×
NEW
126
    }
×
127

NEW
128
    return val.str();
×
NEW
129
};
×
130

NEW
131
// Renders `name` converted to `type` as a C++ `reinterpret_cast` expression.
// The target type is written as an abstract declarator (empty name).
//
// NOTE(review): `reinterpret_cast` cannot convert between arithmetic types
// (e.g. int -> float or float -> double). This helper presumably is only
// reached for pointer-like targets; confirm against the callers.
std::string ROCMLanguageExtension::type_cast(const std::string& name, const types::IType& type) {
    const std::string target_type = declaration("", type);
    return "reinterpret_cast<" + target_type + ">(" + name + ")";
}
×
142

NEW
143
// Renders the access path `sub`, applied to a value of `type`, as a C++
// suffix. Each array/pointer dimension becomes "[i]" and each structure
// member becomes ".member_<k>". The function recurses into the element,
// pointee or member type for the remaining indices.
//
// @returns "" for an empty subset or a scalar type.
// @throws std::invalid_argument for unsupported types, for indexing through an
//         opaque pointer, or for a structure member index that is not an
//         integer constant.
std::string ROCMLanguageExtension::subset(const types::IType& type, const data_flow::Subset& sub) {
    if (sub.empty()) {
        return "";
    }

    if (dynamic_cast<const types::Scalar*>(&type)) {
        return "";
    }

    if (auto array_type = dynamic_cast<const types::Array*>(&type)) {
        const std::string subset_str = "[" + this->expression(sub.at(0)) + "]";
        if (sub.size() == 1) {
            return subset_str;
        }
        const data_flow::Subset element_subset(sub.begin() + 1, sub.end());
        return subset_str + subset(array_type->element_type(), element_subset);
    }

    if (auto pointer_type = dynamic_cast<const types::Pointer*>(&type)) {
        const std::string subset_str = "[" + this->expression(sub.at(0)) + "]";
        // A single index never needs the pointee type. Returning early avoids
        // querying pointee_type() on an opaque (untyped) pointer, mirroring the
        // has_pointee_type() guard used by declaration().
        if (sub.size() == 1) {
            return subset_str;
        }
        if (!pointer_type->has_pointee_type()) {
            throw std::invalid_argument("Cannot index through an opaque pointer");
        }
        const data_flow::Subset element_subset(sub.begin() + 1, sub.end());
        return subset_str + subset(pointer_type->pointee_type(), element_subset);
    }

    if (auto structure_type = dynamic_cast<const types::Structure*>(&type)) {
        const std::string subset_str = ".member_" + this->expression(sub.at(0));
        if (sub.size() == 1) {
            return subset_str;
        }
        // Recursing requires the member's type, which is only known for a
        // constant member index; a failed cast yields a null RCP.
        auto member = SymEngine::rcp_dynamic_cast<const SymEngine::Integer>(sub.at(0));
        if (member.is_null()) {
            throw std::invalid_argument("Structure member index must be an integer constant");
        }
        auto& definition = this->function_.structure(structure_type->name());
        auto& member_type = definition.member_type(member);
        const data_flow::Subset element_subset(sub.begin() + 1, sub.end());
        return subset_str + subset(member_type, element_subset);
    }

    throw std::invalid_argument("Invalid subset type");
}
×
182

NEW
183
// Prints a symbolic expression as C++ source. It uses a CPPSymbolicPrinter
// configured with this extension's function and its external-symbol prefix.
std::string ROCMLanguageExtension::expression(const symbolic::Expression expr) {
    return CPPSymbolicPrinter(this->function_, this->external_prefix_).apply(expr);
}
×
187

NEW
188
// Renders an access node as a C++ expression.
//
// Constant nodes print their literal text, except that a null-pointer
// constant prints as "nullptr". Data nodes print their container name; a
// container that the function marks as external prints as "(&name)".
//
// NOTE(review): expression() passes external_prefix_ to the symbolic printer,
// whereas this path uses the bare container name for externals. Confirm the
// two spellings agree whenever a non-empty prefix is configured.
std::string ROCMLanguageExtension::access_node(const data_flow::AccessNode& node) {
    const std::string name = node.data();
    const bool is_constant = dynamic_cast<const data_flow::ConstantNode*>(&node) != nullptr;

    if (is_constant) {
        return symbolic::is_nullptr(symbolic::symbol(name)) ? std::string("nullptr") : name;
    }
    return this->function_.is_external(name) ? "(&" + name + ")" : name;
}
×
203

NEW
204
std::string ROCMLanguageExtension::tasklet(const data_flow::Tasklet& tasklet) {
×
NEW
205
    switch (tasklet.code()) {
×
NEW
206
        case data_flow::TaskletCode::assign:
×
NEW
207
            return tasklet.inputs().at(0);
×
NEW
208
        case data_flow::TaskletCode::fp_neg:
×
NEW
209
            return "-" + tasklet.inputs().at(0);
×
NEW
210
        case data_flow::TaskletCode::fp_add:
×
NEW
211
            return tasklet.inputs().at(0) + " + " + tasklet.inputs().at(1);
×
NEW
212
        case data_flow::TaskletCode::fp_sub:
×
NEW
213
            return tasklet.inputs().at(0) + " - " + tasklet.inputs().at(1);
×
NEW
214
        case data_flow::TaskletCode::fp_mul:
×
NEW
215
            return tasklet.inputs().at(0) + " * " + tasklet.inputs().at(1);
×
NEW
216
        case data_flow::TaskletCode::fp_div:
×
NEW
217
            return tasklet.inputs().at(0) + " / " + tasklet.inputs().at(1);
×
NEW
218
        case data_flow::TaskletCode::fp_rem:
×
NEW
219
            return "fmod(" + tasklet.inputs().at(0) + ", " + tasklet.inputs().at(1) + ")";
×
NEW
220
        case data_flow::TaskletCode::fp_fma:
×
NEW
221
            return tasklet.inputs().at(0) + " * " + tasklet.inputs().at(1) + " + " + tasklet.inputs().at(2);
×
NEW
222
        case data_flow::TaskletCode::fp_oeq:
×
NEW
223
            return tasklet.inputs().at(0) + " == " + tasklet.inputs().at(1);
×
NEW
224
        case data_flow::TaskletCode::fp_one:
×
NEW
225
            return tasklet.inputs().at(0) + " != " + tasklet.inputs().at(1);
×
NEW
226
        case data_flow::TaskletCode::fp_ogt:
×
NEW
227
            return tasklet.inputs().at(0) + " > " + tasklet.inputs().at(1);
×
NEW
228
        case data_flow::TaskletCode::fp_oge:
×
NEW
229
            return tasklet.inputs().at(0) + " >= " + tasklet.inputs().at(1);
×
NEW
230
        case data_flow::TaskletCode::fp_olt:
×
NEW
231
            return tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1);
×
NEW
232
        case data_flow::TaskletCode::fp_ole:
×
NEW
233
            return tasklet.inputs().at(0) + " <= " + tasklet.inputs().at(1);
×
NEW
234
        case data_flow::TaskletCode::fp_ord:
×
NEW
235
            return "std::isnan(" + tasklet.inputs().at(0) + ") && std::isnan(" + tasklet.inputs().at(1) + ")";
×
NEW
236
        case data_flow::TaskletCode::fp_ueq:
×
NEW
237
            return "std::isnan(" + tasklet.inputs().at(0) + ") || std::isnan(" + tasklet.inputs().at(1) + ")" + " || " +
×
NEW
238
                   tasklet.inputs().at(0) + " == " + tasklet.inputs().at(1);
×
NEW
239
        case data_flow::TaskletCode::fp_une:
×
NEW
240
            return "std::isnan(" + tasklet.inputs().at(0) + ") || std::isnan(" + tasklet.inputs().at(1) + ")" + " || " +
×
NEW
241
                   tasklet.inputs().at(0) + " != " + tasklet.inputs().at(1);
×
NEW
242
        case data_flow::TaskletCode::fp_ugt:
×
NEW
243
            return "std::isnan(" + tasklet.inputs().at(0) + ") || std::isnan(" + tasklet.inputs().at(1) + ")" + " || " +
×
NEW
244
                   tasklet.inputs().at(0) + " > " + tasklet.inputs().at(1);
×
NEW
245
        case data_flow::TaskletCode::fp_uge:
×
NEW
246
            return "std::isnan(" + tasklet.inputs().at(0) + ") || std::isnan(" + tasklet.inputs().at(1) + ")" + " || " +
×
NEW
247
                   tasklet.inputs().at(0) + " >= " + tasklet.inputs().at(1);
×
NEW
248
        case data_flow::TaskletCode::fp_ult:
×
NEW
249
            return "std::isnan(" + tasklet.inputs().at(0) + ") || std::isnan(" + tasklet.inputs().at(1) + ")" + " || " +
×
NEW
250
                   tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1);
×
NEW
251
        case data_flow::TaskletCode::fp_ule:
×
NEW
252
            return "std::isnan(" + tasklet.inputs().at(0) + ") || std::isnan(" + tasklet.inputs().at(1) + ")" + " || " +
×
NEW
253
                   tasklet.inputs().at(0) + " <= " + tasklet.inputs().at(1);
×
NEW
254
        case data_flow::TaskletCode::fp_uno:
×
NEW
255
            return "std::isnan(" + tasklet.inputs().at(0) + ") || std::isnan(" + tasklet.inputs().at(1) + ")";
×
NEW
256
        case data_flow::TaskletCode::int_add:
×
NEW
257
            return tasklet.inputs().at(0) + " + " + tasklet.inputs().at(1);
×
NEW
258
        case data_flow::TaskletCode::int_sub:
×
NEW
259
            return tasklet.inputs().at(0) + " - " + tasklet.inputs().at(1);
×
NEW
260
        case data_flow::TaskletCode::int_mul:
×
NEW
261
            return tasklet.inputs().at(0) + " * " + tasklet.inputs().at(1);
×
NEW
262
        case data_flow::TaskletCode::int_sdiv:
×
NEW
263
            return tasklet.inputs().at(0) + " / " + tasklet.inputs().at(1);
×
NEW
264
        case data_flow::TaskletCode::int_srem:
×
NEW
265
            return tasklet.inputs().at(0) + " % " + tasklet.inputs().at(1);
×
NEW
266
        case data_flow::TaskletCode::int_udiv:
×
NEW
267
            return tasklet.inputs().at(0) + " / " + tasklet.inputs().at(1);
×
NEW
268
        case data_flow::TaskletCode::int_urem:
×
NEW
269
            return tasklet.inputs().at(0) + " % " + tasklet.inputs().at(1);
×
NEW
270
        case data_flow::TaskletCode::int_and:
×
NEW
271
            return tasklet.inputs().at(0) + " & " + tasklet.inputs().at(1);
×
NEW
272
        case data_flow::TaskletCode::int_or:
×
NEW
273
            return tasklet.inputs().at(0) + " | " + tasklet.inputs().at(1);
×
NEW
274
        case data_flow::TaskletCode::int_xor:
×
NEW
275
            return tasklet.inputs().at(0) + " ^ " + tasklet.inputs().at(1);
×
NEW
276
        case data_flow::TaskletCode::int_shl:
×
NEW
277
            return tasklet.inputs().at(0) + " << " + tasklet.inputs().at(1);
×
NEW
278
        case data_flow::TaskletCode::int_lshr:
×
NEW
279
            return tasklet.inputs().at(0) + " >> " + tasklet.inputs().at(1);
×
NEW
280
        case data_flow::TaskletCode::int_ashr:
×
NEW
281
            return tasklet.inputs().at(0) + " >> " + tasklet.inputs().at(1);
×
NEW
282
        case data_flow::TaskletCode::int_smin:
×
NEW
283
            return tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1) + " ? " + tasklet.inputs().at(0) + " : " +
×
NEW
284
                   tasklet.inputs().at(1);
×
NEW
285
        case data_flow::TaskletCode::int_smax:
×
NEW
286
            return tasklet.inputs().at(0) + " > " + tasklet.inputs().at(1) + " ? " + tasklet.inputs().at(0) + " : " +
×
NEW
287
                   tasklet.inputs().at(1);
×
NEW
288
        case data_flow::TaskletCode::int_scmp:
×
NEW
289
            return tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1) + " ? -1 : (" + tasklet.inputs().at(0) +
×
NEW
290
                   " > " + tasklet.inputs().at(1) + " ? 1 : 0)";
×
NEW
291
        case data_flow::TaskletCode::int_umin:
×
NEW
292
            return tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1) + " ? " + tasklet.inputs().at(0) + " : " +
×
NEW
293
                   tasklet.inputs().at(1);
×
NEW
294
        case data_flow::TaskletCode::int_umax:
×
NEW
295
            return tasklet.inputs().at(0) + " > " + tasklet.inputs().at(1) + " ? " + tasklet.inputs().at(0) + " : " +
×
NEW
296
                   tasklet.inputs().at(1);
×
NEW
297
        case data_flow::TaskletCode::int_ucmp:
×
NEW
298
            return tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1) + " ? -1 : (" + tasklet.inputs().at(0) +
×
NEW
299
                   " > " + tasklet.inputs().at(1) + " ? 1 : 0)";
×
NEW
300
        case data_flow::TaskletCode::int_abs:
×
NEW
301
            return "(" + tasklet.inputs().at(0) + " < 0 ? -" + tasklet.inputs().at(0) + " : " + tasklet.inputs().at(0) +
×
NEW
302
                   ")";
×
NEW
303
        case data_flow::TaskletCode::int_eq:
×
NEW
304
            return tasklet.inputs().at(0) + " == " + tasklet.inputs().at(1);
×
NEW
305
        case data_flow::TaskletCode::int_ne:
×
NEW
306
            return tasklet.inputs().at(0) + " != " + tasklet.inputs().at(1);
×
NEW
307
        case data_flow::TaskletCode::int_sgt:
×
NEW
308
            return tasklet.inputs().at(0) + " > " + tasklet.inputs().at(1);
×
NEW
309
        case data_flow::TaskletCode::int_sge:
×
NEW
310
            return tasklet.inputs().at(0) + " >= " + tasklet.inputs().at(1);
×
NEW
311
        case data_flow::TaskletCode::int_slt:
×
NEW
312
            return tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1);
×
NEW
313
        case data_flow::TaskletCode::int_sle:
×
NEW
314
            return tasklet.inputs().at(0) + " <= " + tasklet.inputs().at(1);
×
NEW
315
        case data_flow::TaskletCode::int_ugt:
×
NEW
316
            return tasklet.inputs().at(0) + " > " + tasklet.inputs().at(1);
×
NEW
317
        case data_flow::TaskletCode::int_uge:
×
NEW
318
            return tasklet.inputs().at(0) + " >= " + tasklet.inputs().at(1);
×
NEW
319
        case data_flow::TaskletCode::int_ult:
×
NEW
320
            return tasklet.inputs().at(0) + " < " + tasklet.inputs().at(1);
×
NEW
321
        case data_flow::TaskletCode::int_ule:
×
NEW
322
            return tasklet.inputs().at(0) + " <= " + tasklet.inputs().at(1);
×
NEW
323
    };
×
NEW
324
    throw std::invalid_argument("Invalid tasklet code");
×
NEW
325
};
×
326

NEW
327
// Returns a C++/HIP literal spelling the zero value of `prim_type`.
//
// @param prim_type  Primitive type whose zero literal is requested.
// @returns          A literal such as "0", "0ull", "0.0f" or "(__fp16)0.0f".
// @throws InvalidSDFGException for Void, which has no zero value.
// @throws std::runtime_error for a value outside the enumeration. Previously,
//         control fell off the end of this non-void function in that case,
//         which is undefined behavior.
std::string ROCMLanguageExtension::zero(const types::PrimitiveType prim_type) {
    switch (prim_type) {
        case types::Void:
            throw InvalidSDFGException("No zero for void type possible");
        case types::Bool:
            return "false";

        // Integers; plain "0" converts implicitly to the 128-bit types.
        case types::Int8:
        case types::Int16:
        case types::Int32:
        case types::Int128:
        case types::UInt128:
            return "0";
        case types::Int64:
            return "0ll";
        case types::UInt8:
        case types::UInt16:
        case types::UInt32:
            return "0u";
        case types::UInt64:
            return "0ull";

        // Floating point
        case types::Half:
            return "(__fp16)0.0f";
        case types::BFloat:
            return "(__bf16)0.0f";
        case types::Float:
            return "0.0f";
        case types::Double:
        case types::FP128:
        case types::PPC_FP128:
            return "0.0";
        case types::X86_FP80:
            return "0.0l";
    }
    throw std::runtime_error("Unknown primitive type");
}
×
369

370
} // namespace codegen
371
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc