• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 28106147644

24 Jun 2026 02:32PM UTC coverage: 61.922% (+0.1%) from 61.779%
28106147644

Pull #806

github

web-flow
Merge 2be414d54 into 57cc1db99
Pull Request #806: Map Collapse for Multiple targets in a nested sequence

165 of 185 new or added lines in 2 files covered. (89.19%)

419 existing lines in 30 files now uncovered.

37705 of 60891 relevant lines covered (61.92%)

1004.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.27
/opt/src/transformations/offloading/rocblas_data_transfer_extraction.cpp
1
#include "sdfg/transformations/offloading/rocblas_data_transfer_extraction.h"
2

3
#include <cassert>
4
#include <cstddef>
5
#include <nlohmann/json_fwd.hpp>
6
#include <string>
7
#include <unordered_map>
8

9
#include "sdfg/analysis/analysis.h"
10
#include "sdfg/builder/structured_sdfg_builder.h"
11
#include "sdfg/data_flow/access_node.h"
12
#include "sdfg/data_flow/library_nodes/math/blas/dot_node.h"
13
#include "sdfg/data_flow/library_nodes/math/blas/gemm_node.h"
14
#include "sdfg/data_flow/library_nodes/math/math.h"
15
#include "sdfg/element.h"
16
#include "sdfg/exceptions.h"
17
#include "sdfg/structured_control_flow/block.h"
18
#include "sdfg/structured_control_flow/sequence.h"
19
#include "sdfg/symbolic/symbolic.h"
20
#include "sdfg/targets/rocm/rocm.h"
21
#include "sdfg/targets/rocm/rocm_data_offloading_node.h"
22
#include "sdfg/transformations/transformation.h"
23
#include "sdfg/types/type.h"
24
#include "sdfg/types/utils.h"
25
#include "symengine/symengine_rcp.h"
26

27
namespace sdfg {
28
namespace rocm {
29

30
std::string ROCBLASDataTransferExtraction::create_device_container(
31
    builder::StructuredSDFGBuilder& builder, const types::Pointer& type, const symbolic::Expression& size
32
) {
5✔
33
    auto new_type = type.clone();
5✔
34
    new_type->storage_type(types::StorageType(
5✔
35
        "AMD_Generic", size, types::StorageType::AllocationType::Unmanaged, types::StorageType::AllocationType::Unmanaged
5✔
36
    ));
5✔
37
    auto device_container = builder.find_new_name(ROCM_DEVICE_PREFIX);
5✔
38
    builder.add_container(device_container, *new_type);
5✔
39
    return device_container;
5✔
40
}
5✔
41

42
void ROCBLASDataTransferExtraction::create_allocate(
43
    builder::StructuredSDFGBuilder& builder,
44
    structured_control_flow::Sequence& sequence,
45
    structured_control_flow::Block& block,
46
    const std::string& device_container,
47
    const symbolic::Expression& size,
48
    const types::Pointer& type
49
) {
×
50
    auto& alloc_block = builder.add_block_before(sequence, block, {}, block.debug_info());
×
51
    offloading::add_offloading_node<ROCMDataOffloadingNode>(
×
52
        builder,
×
53
        alloc_block,
×
54
        device_container,
×
55
        device_container,
×
56
        offloading::DataTransferDirection::NONE,
×
57
        offloading::BufferLifecycle::ALLOC,
×
58
        type,
×
59
        type,
×
60
        this->blas_node_.debug_info(),
×
61
        size,
×
62
        symbolic::zero()
×
63
    );
×
64
}
×
65

66
void ROCBLASDataTransferExtraction::create_deallocate(
67
    builder::StructuredSDFGBuilder& builder,
68
    structured_control_flow::Sequence& sequence,
69
    structured_control_flow::Block& block,
70
    const std::string& device_container,
71
    const types::Pointer& type
72
) {
4✔
73
    auto& dealloc_block = builder.add_block_after(sequence, block, {}, block.debug_info());
4✔
74
    offloading::add_offloading_node<ROCMDataOffloadingNode>(
4✔
75
        builder,
4✔
76
        dealloc_block,
4✔
77
        device_container,
4✔
78
        device_container,
4✔
79
        offloading::DataTransferDirection::NONE,
4✔
80
        offloading::BufferLifecycle::FREE,
4✔
81
        type,
4✔
82
        type,
4✔
83
        this->blas_node_.debug_info(),
4✔
84
        SymEngine::null,
4✔
85
        symbolic::zero()
4✔
86
    );
4✔
87
}
4✔
88

89
void ROCBLASDataTransferExtraction::create_copy_to_device(
90
    builder::StructuredSDFGBuilder& builder,
91
    structured_control_flow::Sequence& sequence,
92
    structured_control_flow::Block& block,
93
    const std::string& host_container,
94
    const std::string& device_container,
95
    const symbolic::Expression& size,
96
    const types::Pointer& type
97
) {
×
98
    auto& copy_block = builder.add_block_before(sequence, block, {}, block.debug_info());
×
99
    offloading::add_offloading_node<ROCMDataOffloadingNode>(
×
100
        builder,
×
101
        copy_block,
×
102
        host_container,
×
103
        device_container,
×
104
        offloading::DataTransferDirection::H2D,
×
105
        offloading::BufferLifecycle::NO_CHANGE,
×
106
        type,
×
107
        type,
×
108
        this->blas_node_.debug_info(),
×
109
        size,
×
110
        symbolic::zero()
×
111
    );
×
112
}
×
113

114
void ROCBLASDataTransferExtraction::create_copy_from_device(
115
    builder::StructuredSDFGBuilder& builder,
116
    structured_control_flow::Sequence& sequence,
117
    structured_control_flow::Block& block,
118
    const std::string& host_container,
119
    const std::string& device_container,
120
    const symbolic::Expression& size,
121
    const types::Pointer& type
122
) {
×
123
    auto& copy_block = builder.add_block_after(sequence, block, {}, block.debug_info());
×
124
    offloading::add_offloading_node<ROCMDataOffloadingNode>(
×
125
        builder,
×
126
        copy_block,
×
127
        host_container,
×
128
        device_container,
×
129
        offloading::DataTransferDirection::D2H,
×
130
        offloading::BufferLifecycle::NO_CHANGE,
×
131
        type,
×
132
        type,
×
133
        this->blas_node_.debug_info(),
×
134
        size,
×
135
        symbolic::zero()
×
136
    );
×
137
}
×
138

139
void ROCBLASDataTransferExtraction::create_copy_to_device_with_allocation(
140
    builder::StructuredSDFGBuilder& builder,
141
    structured_control_flow::Sequence& sequence,
142
    structured_control_flow::Block& block,
143
    const std::string& host_container,
144
    const std::string& device_container,
145
    const symbolic::Expression& size,
146
    const types::Pointer& type
147
) {
5✔
148
    auto& copy_block = builder.add_block_before(sequence, block, {}, block.debug_info());
5✔
149
    offloading::add_offloading_node<ROCMDataOffloadingNode>(
5✔
150
        builder,
5✔
151
        copy_block,
5✔
152
        host_container,
5✔
153
        device_container,
5✔
154
        offloading::DataTransferDirection::H2D,
5✔
155
        offloading::BufferLifecycle::ALLOC,
5✔
156
        type,
5✔
157
        type,
5✔
158
        this->blas_node_.debug_info(),
5✔
159
        size,
5✔
160
        symbolic::zero()
5✔
161
    );
5✔
162
}
5✔
163

164
void ROCBLASDataTransferExtraction::create_copy_from_device_with_deallocation(
165
    builder::StructuredSDFGBuilder& builder,
166
    structured_control_flow::Sequence& sequence,
167
    structured_control_flow::Block& block,
168
    const std::string& host_container,
169
    const std::string& device_container,
170
    const symbolic::Expression& size,
171
    const types::Pointer& type
172
) {
1✔
173
    auto& copy_block = builder.add_block_after(sequence, block, {}, block.debug_info());
1✔
174
    offloading::add_offloading_node<ROCMDataOffloadingNode>(
1✔
175
        builder,
1✔
176
        copy_block,
1✔
177
        host_container,
1✔
178
        device_container,
1✔
179
        offloading::DataTransferDirection::D2H,
1✔
180
        offloading::BufferLifecycle::FREE,
1✔
181
        type,
1✔
182
        type,
1✔
183
        this->blas_node_.debug_info(),
1✔
184
        size,
1✔
185
        symbolic::zero()
1✔
186
    );
1✔
187
}
1✔
188

189
// Creates the transformation for the given BLAS node; the node is stored in
// the `blas_node_` member and is what can_be_applied()/apply() operate on.
ROCBLASDataTransferExtraction::ROCBLASDataTransferExtraction(math::blas::BLASNode& blas_node)
    : blas_node_(blas_node) {
}
10✔
190

191
// Returns the name of this transformation; to_json() writes the same string
// into the "transformation_type" field.
std::string ROCBLASDataTransferExtraction::name() const {
    return "ROCBLASDataTransferExtraction";
}
4✔
192

193
bool ROCBLASDataTransferExtraction::
194
    can_be_applied(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
6✔
195
    // BLAS node must have implementation type ROCMBLAS with data transfers
196
    if (this->blas_node_.implementation_type().value() != rocm::ImplementationType_ROCMWithTransfers.value()) {
6✔
197
        return false;
2✔
198
    }
2✔
199

200
    // Restrict to BLAS nodes in their own block
201
    auto& dfg = this->blas_node_.get_parent();
4✔
202
    if (dfg.nodes().size() != dfg.in_degree(this->blas_node_) + dfg.out_degree(this->blas_node_) + 1) {
4✔
203
        return false;
×
204
    }
×
205

206
    // Supported BLAS nodes
207
    if (dynamic_cast<math::blas::DotNode*>(&this->blas_node_)) {
4✔
208
        return true;
2✔
209
    } else if (dynamic_cast<math::blas::GEMMNode*>(&this->blas_node_)) {
2✔
210
        return true;
2✔
211
    } else {
2✔
212
        return false;
×
213
    }
×
214
}
4✔
215

216
void ROCBLASDataTransferExtraction::
217
    apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
2✔
218
    // Get data flow graph and block
219
    auto& dfg = this->blas_node_.get_parent();
2✔
220
    auto* block = dynamic_cast<structured_control_flow::Block*>(dfg.get_parent());
2✔
221
    assert(block);
2✔
222

223
    // Get sequence
224
    auto* sequence = dynamic_cast<structured_control_flow::Sequence*>(block->get_parent());
2✔
225
    assert(sequence);
2✔
226

227
    // Determine type
228
    types::PrimitiveType precision;
2✔
229
    switch (this->blas_node_.precision()) {
2✔
230
        case math::blas::h:
×
231
            precision = types::PrimitiveType::Half;
×
232
            break;
×
233
        case math::blas::s:
1✔
234
            precision = types::PrimitiveType::Float;
1✔
235
            break;
1✔
236
        case math::blas::d:
1✔
237
            precision = types::PrimitiveType::Double;
1✔
238
            break;
1✔
239
        default:
×
240
            throw InvalidSDFGException("ROCBLASDataTransferExtraction: Unsupported precision");
×
241
    }
2✔
242
    types::Scalar base_type(precision);
2✔
243
    types::Pointer type(base_type);
2✔
244

245
    // Capture in and out accesses
246
    std::unordered_map<std::string, data_flow::AccessNode&> in_access, out_access;
2✔
247
    for (auto& iedge : dfg.in_edges(this->blas_node_)) {
7✔
248
        in_access.insert({iedge.dst_conn(), static_cast<data_flow::AccessNode&>(iedge.src())});
7✔
249
    }
7✔
250
    for (auto& oedge : dfg.out_edges(this->blas_node_)) {
2✔
251
        out_access.insert({oedge.src_conn(), static_cast<data_flow::AccessNode&>(oedge.dst())});
1✔
252
    }
1✔
253

254
    if (auto* dot_node = dynamic_cast<math::blas::DotNode*>(&this->blas_node_)) {
2✔
255
        auto x_size = symbolic::mul(
1✔
256
            symbolic::add(symbolic::mul(symbolic::sub(dot_node->n(), symbolic::one()), dot_node->incx()), symbolic::one()),
1✔
257
            types::get_contiguous_element_size(type, true)
1✔
258
        );
1✔
259
        auto y_size = symbolic::mul(
1✔
260
            symbolic::add(symbolic::mul(symbolic::sub(dot_node->n(), symbolic::one()), dot_node->incy()), symbolic::one()),
1✔
261
            types::get_contiguous_element_size(type, true)
1✔
262
        );
1✔
263
        auto dx = this->create_device_container(builder, type, x_size);
1✔
264
        auto dy = this->create_device_container(builder, type, y_size);
1✔
265

266
        this->create_copy_to_device_with_allocation(
1✔
267
            builder, *sequence, *block, in_access.at("__x").data(), dx, x_size, type
1✔
268
        );
1✔
269
        this->create_copy_to_device_with_allocation(
1✔
270
            builder, *sequence, *block, in_access.at("__y").data(), dy, y_size, type
1✔
271
        );
1✔
272

273
        this->create_deallocate(builder, *sequence, *block, dx, type);
1✔
274
        this->create_deallocate(builder, *sequence, *block, dy, type);
1✔
275

276
        in_access.at("__x").data(dx);
1✔
277
        in_access.at("__y").data(dy);
1✔
278
    } else if (auto* gemm_node = dynamic_cast<math::blas::GEMMNode*>(&this->blas_node_)) {
1✔
279
        auto elem_size = types::get_contiguous_element_size(type, true);
1✔
280
        auto a_size = symbolic::mul(symbolic::mul(gemm_node->m(), gemm_node->k()), elem_size);
1✔
281
        auto b_size = symbolic::mul(symbolic::mul(gemm_node->k(), gemm_node->n()), elem_size);
1✔
282
        auto c_size = symbolic::mul(symbolic::mul(gemm_node->m(), gemm_node->n()), elem_size);
1✔
283

284
        auto dA = this->create_device_container(builder, type, a_size);
1✔
285
        auto dB = this->create_device_container(builder, type, b_size);
1✔
286
        auto dC = this->create_device_container(builder, type, c_size);
1✔
287

288
        this->create_copy_to_device_with_allocation(
1✔
289
            builder, *sequence, *block, in_access.at("__A").data(), dA, a_size, type
1✔
290
        );
1✔
291
        this->create_copy_to_device_with_allocation(
1✔
292
            builder, *sequence, *block, in_access.at("__B").data(), dB, b_size, type
1✔
293
        );
1✔
294
        auto c_ptr = in_access.at("__C").data();
1✔
295
        this->create_copy_to_device_with_allocation(builder, *sequence, *block, c_ptr, dC, c_size, type);
1✔
296

297
        this->create_copy_from_device_with_deallocation(builder, *sequence, *block, c_ptr, dC, c_size, type);
1✔
298
        this->create_deallocate(builder, *sequence, *block, dA, type);
1✔
299
        this->create_deallocate(builder, *sequence, *block, dB, type);
1✔
300

301
        in_access.at("__A").data(dA);
1✔
302
        in_access.at("__B").data(dB);
1✔
303
        in_access.at("__C").data(dC);
1✔
304
    } else {
1✔
305
        throw InvalidSDFGException("ROCBLASDataTransferExtraction: Unsupported BLAS type");
×
306
    }
×
307

308
    // Change the implementation type to ROCMBLAS without data transfers
309
    this->blas_node_.implementation_type() = rocm::ImplementationType_ROCMWithoutTransfers;
2✔
310
}
2✔
311

312
// Serializes this transformation into `j`:
//   "transformation_type": name(),
//   "parameters":          an empty object,
//   "subgraph":            {"0": {"element_id": <id of the BLAS node>, "type": "unknown"}}
void ROCBLASDataTransferExtraction::to_json(nlohmann::json& j) const {
    j["transformation_type"] = this->name();
    j["parameters"] = nlohmann::json::object();

    // Start from a fresh object so any previous "subgraph" content is replaced.
    j["subgraph"] = nlohmann::json::object();
    auto& node_entry = j["subgraph"]["0"];
    node_entry["element_id"] = this->blas_node_.element_id();
    node_entry["type"] = "unknown";
}
2✔
317

318
// Rebuilds the transformation from its JSON description.
//
// Reads the element id at j["subgraph"]["0"]["element_id"], looks the element
// up through `builder` and requires it to be a BLASNode.
//
// Throws transformations::InvalidTransformationDescriptionException if no
// element with that id exists or if the element is not a BLASNode. Missing
// JSON keys are reported by the json `at()` accessors themselves.
ROCBLASDataTransferExtraction ROCBLASDataTransferExtraction::
    from_json(builder::StructuredSDFGBuilder& builder, const nlohmann::json& j) {
    const size_t element_id = j.at("subgraph").at("0").at("element_id").get<size_t>();

    auto* element = builder.find_element_by_id(element_id);
    if (element == nullptr) {
        throw transformations::
            InvalidTransformationDescriptionException("Element with ID " + std::to_string(element_id) + " not found");
    }

    auto* blas_node = dynamic_cast<math::blas::BLASNode*>(element);
    if (blas_node == nullptr) {
        throw transformations::InvalidTransformationDescriptionException(
            "Element with ID " + std::to_string(element_id) + " is not a BLASNode"
        );
    }

    return ROCBLASDataTransferExtraction(*blas_node);
}
2✔
338

339
} // namespace rocm
340
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc