• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23490543373

24 Mar 2026 12:56PM UTC coverage: 64.456% (+0.2%) from 64.295%
23490543373

Pull #605

github

web-flow
Merge 28bc2690b into e56781552
Pull Request #605: Move einsum support

1303 of 1918 new or added lines in 14 files covered. (67.94%)

45 existing lines in 3 files now uncovered.

27952 of 43366 relevant lines covered (64.46%)

392.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/opt/src/targets/rocm/rocm_data_offloading_node.cpp
1
#include "sdfg/targets/rocm/rocm_data_offloading_node.h"
2

3
#include <cstddef>
4
#include <memory>
5
#include <nlohmann/json_fwd.hpp>
6

7
#include "sdfg/analysis/loop_analysis.h"
8
#include "sdfg/codegen/code_snippet_factory.h"
9
#include "sdfg/codegen/dispatchers/block_dispatcher.h"
10
#include "sdfg/codegen/instrumentation/instrumentation_info.h"
11
#include "sdfg/codegen/language_extension.h"
12
#include "sdfg/codegen/utils.h"
13
#include "sdfg/data_flow/data_flow_graph.h"
14
#include "sdfg/data_flow/data_flow_node.h"
15
#include "sdfg/data_flow/library_node.h"
16
#include "sdfg/function.h"
17
#include "sdfg/graph/graph.h"
18
#include "sdfg/symbolic/symbolic.h"
19
#include "sdfg/targets/offloading/data_offloading_node.h"
20
#include "sdfg/targets/rocm/rocm.h"
21
#include "symengine/symengine_rcp.h"
22

23
namespace sdfg {
24
namespace rocm {
25

26
ROCMDataOffloadingNode::ROCMDataOffloadingNode(
27
    size_t element_id,
28
    const DebugInfo& debug_info,
29
    const graph::Vertex vertex,
30
    data_flow::DataFlowGraph& parent,
31
    symbolic::Expression size,
32
    symbolic::Expression device_id,
33
    offloading::DataTransferDirection transfer_direction,
34
    offloading::BufferLifecycle buffer_lifecycle
35
)
36
    : offloading::DataOffloadingNode(
×
37
          element_id,
×
38
          debug_info,
×
39
          vertex,
×
40
          parent,
×
41
          LibraryNodeType_ROCM_Offloading,
×
42
          {},
×
43
          {},
×
44
          transfer_direction,
×
45
          buffer_lifecycle,
×
46
          size
×
47
      ),
×
48
      device_id_(device_id) {
×
49
    if (!is_NONE(transfer_direction)) {
×
50
        this->inputs_.push_back("_src");
×
51
        this->outputs_.push_back("_dst");
×
52
    } else if (is_ALLOC(buffer_lifecycle)) {
×
53
        this->outputs_.push_back("_ret");
×
54
    } else if (is_FREE(buffer_lifecycle)) {
×
55
        this->inputs_.push_back("_ptr");
×
56
        this->outputs_.push_back("_ptr");
×
57
    }
×
58
}
×
59

60
void ROCMDataOffloadingNode::validate(const Function& function) const {
×
61
    // Prevent copy-in and free
62
    if (this->is_h2d() && this->is_free()) {
×
63
        throw InvalidSDFGException("ROCMDataOffloadingNode: Combination copy-in and free is not allowed");
×
64
    }
×
65

66
    // Prevent copy-out and alloc
67
    if (this->is_d2h() && this->is_alloc()) {
×
68
        throw InvalidSDFGException("ROCMDataOffloadingNode: Combination copy-out and alloc is not allowed");
×
69
    }
×
70
}
×
71

72
// Returns the (possibly null) symbolic device id this node targets.
const symbolic::Expression ROCMDataOffloadingNode::device_id() const {
    return device_id_;
}
×
73

74
std::unique_ptr<data_flow::DataFlowNode> ROCMDataOffloadingNode::
75
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
76
    return std::make_unique<ROCMDataOffloadingNode>(
×
77
        element_id,
×
78
        this->debug_info(),
×
79
        vertex,
×
80
        parent,
×
81
        this->size(),
×
82
        this->device_id(),
×
83
        this->transfer_direction(),
×
84
        this->buffer_lifecycle()
×
85
    );
×
86
}
×
87

88
// Symbols referenced by this node: everything the base offloading node reports,
// plus the atoms of the device id expression when one is set (a null device id
// contributes nothing).
symbolic::SymbolSet ROCMDataOffloadingNode::symbols() const {
    auto result = offloading::DataOffloadingNode::symbols();

    const auto dev = this->device_id();
    if (!dev.is_null()) {
        for (const auto& atom : symbolic::atoms(dev)) {
            result.insert(atom);
        }
    }
    return result;
}
×
97

98
// Substitutes `old_expression` with `new_expression` in all symbolic
// properties of this node: first in the base offloading node's expressions,
// then in the device id.
//
// The device id may be null (symbols() explicitly treats that as a valid
// state), so substitution is skipped in that case instead of handing a null
// expression to symbolic::subs.
void ROCMDataOffloadingNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    offloading::DataOffloadingNode::replace(old_expression, new_expression);
    if (!this->device_id_.is_null()) {
        this->device_id_ = symbolic::subs(this->device_id_, old_expression, new_expression);
    }
}
×
102

103
// ROCm offloading nodes always report themselves as blocking. The dispatcher
// emits plain hipMalloc/hipMemcpy/hipFree calls (no async variants).
bool ROCMDataOffloadingNode::blocking() const {
    constexpr bool kIsBlocking = true;
    return kIsBlocking;
}
×
104

105
// Two ROCm offloading nodes are redundant when the base offloading node says so
// and they target the same device id (compared with null_safe_eq, so two null
// ids compare equal under that helper's rules).
// NOTE(review): `other` is static_cast to ROCMDataOffloadingNode; this assumes
// the base check already rejects nodes of other kinds — confirm.
bool ROCMDataOffloadingNode::redundant_with(const offloading::DataOffloadingNode& other) const {
    if (offloading::DataOffloadingNode::redundant_with(other)) {
        const auto& rocm_other = static_cast<const ROCMDataOffloadingNode&>(other);
        return symbolic::null_safe_eq(this->device_id(), rocm_other.device_id());
    }
    return false;
}
×
117

118
// Structural equality: the base offloading node must consider the nodes equal,
// and both must target the same (null-safe compared) device id.
// NOTE(review): relies on the base check to guarantee `other` really is a
// ROCMDataOffloadingNode before the static_cast — confirm.
bool ROCMDataOffloadingNode::equal_with(const offloading::DataOffloadingNode& other) const {
    return offloading::DataOffloadingNode::equal_with(other) &&
           symbolic::null_safe_eq(
               this->device_id(), static_cast<const ROCMDataOffloadingNode&>(other).device_id()
           );
}
×
130

131
// Code-generation dispatcher for ROCMDataOffloadingNode. The constructor only
// forwards its arguments to the generic LibraryNodeDispatcher base.
ROCMDataOffloadingNodeDispatcher::ROCMDataOffloadingNodeDispatcher(
    codegen::LanguageExtension& language_extension,
    const Function& function,
    const data_flow::DataFlowGraph& data_flow_graph,
    const data_flow::LibraryNode& node
)
    : codegen::LibraryNodeDispatcher(language_extension, function, data_flow_graph, node) {
    // Nothing ROCm-specific to initialize.
}
×
138

139
void ROCMDataOffloadingNodeDispatcher::dispatch_code(
140
    codegen::PrettyPrinter& stream,
141
    codegen::PrettyPrinter& globals_stream,
142
    codegen::CodeSnippetFactory& library_snippet_factory
UNCOV
143
) {
×
UNCOV
144
    auto& offloading_node = static_cast<const ROCMDataOffloadingNode&>(this->node_);
×
145

UNCOV
146
    stream << "hipError_t err;" << std::endl;
×
147

148
    if (offloading_node.is_alloc()) {
×
149
        stream << "err = hipMalloc(&" << offloading_node.output(0) << ", "
×
UNCOV
150
               << this->language_extension_.expression(offloading_node.size()) << ");" << std::endl;
×
151
        rocm_error_checking(stream, this->language_extension_, "err");
×
UNCOV
152
    }
×
153

154
    if (offloading_node.is_h2d()) {
×
155
        stream << "err = hipMemcpy(" << offloading_node.output(0) << ", " << offloading_node.input(0) << ", "
×
156
               << this->language_extension_.expression(offloading_node.size()) << ", hipMemcpyHostToDevice);"
×
157
               << std::endl;
×
UNCOV
158
        rocm_error_checking(stream, this->language_extension_, "err");
×
159
    } else if (offloading_node.is_d2h()) {
×
160
        stream << "err = hipMemcpy(" << offloading_node.output(0) << ", " << offloading_node.input(0) << ", "
×
161
               << this->language_extension_.expression(offloading_node.size()) << ", hipMemcpyDeviceToHost);"
×
162
               << std::endl;
×
163
        rocm_error_checking(stream, this->language_extension_, "err");
×
164
    }
×
165

166
    if (offloading_node.is_free()) {
×
167
        stream << "err = hipFree(" << offloading_node.input(0) << ");" << std::endl;
×
168
        rocm_error_checking(stream, this->language_extension_, "err");
×
169
    }
×
UNCOV
170
}
×
171

172
// Instrumentation metadata for the node:
//   - d2h / h2d transfers are tagged with the matching transfer element type,
//     the ROCM target, and a "pcie_bytes" entry holding the size expression;
//   - any other node (no transfer) falls back to the generic library-node info.
codegen::InstrumentationInfo ROCMDataOffloadingNodeDispatcher::instrumentation_info() const {
    const auto& rocm_node = static_cast<const ROCMDataOffloadingNode&>(node_);

    if (!rocm_node.is_d2h() && !rocm_node.is_h2d()) {
        return codegen::LibraryNodeDispatcher::instrumentation_info();
    }

    // d2h takes precedence, matching the branch order of the transfer check.
    const auto& element_type =
        rocm_node.is_d2h() ? codegen::ElementType_D2HTransfer : codegen::ElementType_H2DTransfer;

    return codegen::InstrumentationInfo(
        node_.element_id(),
        element_type,
        TargetType_ROCM,
        analysis::LoopInfo{},
        {{"pcie_bytes", language_extension_.expression(rocm_node.size())}}
    );
}
×
194

195
// Serializes a ROCMDataOffloadingNode to JSON.
//
// Written keys: "type", "element_id", the flat debug-info keys ("has",
// "filename", "start_line", "start_column", "end_line", "end_column"), "code",
// "size", "device_id", "transfer_direction", "buffer_lifecycle".
//
// Both `size` and `device_id` may be null expressions (symbols() treats a null
// device id as valid), so each is written as JSON null in that case instead of
// being passed to the expression printer.
// NOTE(review): deserialize() reads a nested "debug_info" object, while this
// function writes flat keys; presumably the enclosing JSONSerializer adds
// "debug_info" itself — confirm.
nlohmann::json ROCMDataOffloadingNodeSerializer::serialize(const sdfg::data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const ROCMDataOffloadingNode&>(library_node);
    nlohmann::json j;

    // Library node
    j["type"] = "library_node";
    j["element_id"] = library_node.element_id();

    // Debug info
    auto& debug_info = library_node.debug_info();
    j["has"] = debug_info.has();
    j["filename"] = debug_info.filename();
    j["start_line"] = debug_info.start_line();
    j["start_column"] = debug_info.start_column();
    j["end_line"] = debug_info.end_line();
    j["end_column"] = debug_info.end_column();

    // Library node properties
    j["code"] = std::string(library_node.code().value());

    // Offloading node properties
    sdfg::serializer::JSONSerializer serializer;

    // Null-safe expression -> JSON: null expressions become JSON null.
    auto expression_to_json = [&serializer](const symbolic::Expression& expr) -> nlohmann::json {
        if (expr.is_null()) {
            return nlohmann::json(nlohmann::json::value_t::null);
        }
        return serializer.expression(expr);
    };

    j["size"] = expression_to_json(node.size());
    j["device_id"] = expression_to_json(node.device_id());
    j["transfer_direction"] = static_cast<int8_t>(node.transfer_direction());
    j["buffer_lifecycle"] = static_cast<int8_t>(node.buffer_lifecycle());

    return j;
}
×
228

229
// Rebuilds a ROCMDataOffloadingNode from JSON and adds it to `parent` via
// `builder`.
//
// Required keys: "code" (must equal the ROCM offloading library-node code),
// "debug_info", "transfer_direction", "buffer_lifecycle".
// Optional keys: "size" and "device_id"; a missing key or JSON null yields a
// null expression, otherwise the value is parsed with symbolic::parse (the
// same representation serialize() writes for both).
//
// Throws std::runtime_error on a mismatching code. Lookups use .at(), which
// throws on a missing key; operator[] on a const json with a missing key is
// undefined behavior.
data_flow::LibraryNode& ROCMDataOffloadingNodeSerializer::deserialize(
    const nlohmann::json& j, sdfg::builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    const auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_ROCM_Offloading.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    // Absent / JSON-null -> null expression; otherwise parse the stored expression.
    auto parse_optional_expression = [&j](const std::string& key) -> symbolic::Expression {
        if (!j.contains(key) || j.at(key).is_null()) {
            return SymEngine::null;
        }
        return symbolic::parse(j.at(key));
    };

    symbolic::Expression size = parse_optional_expression("size");
    symbolic::Expression device_id = parse_optional_expression("device_id");
    auto transfer_direction =
        static_cast<offloading::DataTransferDirection>(j.at("transfer_direction").get<int8_t>());
    auto buffer_lifecycle = static_cast<offloading::BufferLifecycle>(j.at("buffer_lifecycle").get<int8_t>());

    return builder.add_library_node<
        ROCMDataOffloadingNode>(parent, debug_info, size, device_id, transfer_direction, buffer_lifecycle);
}
×
253

254
} // namespace rocm
255
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc