• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

celerity / celerity-runtime / 9907467165

12 Jul 2024 11:55AM UTC coverage: 93.527% (-1.1%) from 94.676%
9907467165

Pull #255

github

fknorr
Upstream multi-device selection & tests

Co-authored-by: Philip Salzmann <philip.salzmann@uibk.ac.at>
Pull Request #255: [IDAG] Upstream Multi-Device Selection + Tests

3135 of 3561 branches covered (88.04%)

Branch coverage included in aggregate %.

107 of 109 new or added lines in 2 files covered. (98.17%)

65 existing lines in 1 file now uncovered.

7094 of 7376 relevant lines covered (96.18%)

199351.17 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

39.23
/include/device_queue.h
1
#pragma once
2

3
#include <algorithm>
4
#include <memory>
5
#include <type_traits>
6
#include <variant>
7

8
#include <CL/sycl.hpp>
9

10
#include "backend/backend.h"
11
#include "config.h"
12
#include "log.h"
13
#include "workaround.h"
14

15
namespace celerity {
16
namespace detail {
17

18
        struct auto_select_device {};
19
        using device_selector = std::function<int(const sycl::device&)>;
20
        using device_or_selector = std::variant<auto_select_device, sycl::device, device_selector>;
21

22
        class task;
23

24
        /// A raw device pointer together with its size in bytes, as returned by device_queue::malloc.
        /// A default-constructed instance represents "no allocation" (null pointer, zero bytes).
        struct device_allocation {
                void* ptr{nullptr};
                size_t size_bytes{0};
        };
28

29
        /// Thrown by device_queue::malloc when a device allocation cannot be satisfied.
        class allocation_error : public std::runtime_error {
          public:
                allocation_error(const std::string& message) : std::runtime_error{message} {}
        };
33

34
        /**
35
         * The @p device_queue wraps the actual SYCL queue and is used to submit kernels.
36
         */
37
        class device_queue {
38
          public:
39
                /**
40
                 * @brief Initializes the @p device_queue, selecting an appropriate device in the process.
41
                 *
42
                 * @param cfg The configuration is used to select the appropriate SYCL device.
43
                 * @param user_device_or_selector Optionally a device (which will take precedence over any configuration) or a device selector can be provided.
44
                 */
45
                void init(const config& cfg, const device_or_selector& user_device_or_selector);
46

47
                /**
48
                 * @brief Executes the kernel associated with task @p ctsk over the chunk @p chnk.
49
                 */
50
                template <typename Fn>
51
                cl::sycl::event submit(Fn&& fn) {
252✔
52
                        auto evt = m_sycl_queue->submit([fn = std::forward<Fn>(fn)](cl::sycl::handler& sycl_handler) { fn(sycl_handler); });
504✔
53
#if CELERITY_WORKAROUND(HIPSYCL)
54
#pragma GCC diagnostic push
55
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
56
                        // hipSYCL does not guarantee that command groups are actually scheduled until an explicit await operation, which we cannot insert without
57
                        // blocking the executor loop (see https://github.com/illuhad/hipSYCL/issues/599). Instead, we explicitly flush the queue to be able to continue
58
                        // using our polling-based approach.
59
                        m_sycl_queue->get_context().hipSYCL_runtime()->dag().flush_async();
60
#pragma GCC diagnostic pop
61
#endif
62
                        return evt;
252✔
63
                }
64

65
                template <typename T>
66
                [[nodiscard]] device_allocation malloc(const size_t count) {
201✔
67
                        const size_t size_bytes = count * sizeof(T);
201✔
68
                        assert(m_sycl_queue != nullptr);
201✔
69
                        assert(m_global_mem_allocated_bytes + size_bytes < m_global_mem_total_size_bytes);
201✔
70
                        CELERITY_DEBUG("Allocating {} bytes on device", size_bytes);
402✔
71
                        T* ptr = nullptr;
201✔
72
                        try {
73
                                ptr = sycl::aligned_alloc_device<T>(alignof(T), count, *m_sycl_queue);
201✔
74
                        } catch(sycl::exception& e) {
×
75
                                CELERITY_CRITICAL("sycl::aligned_alloc_device failed with exception: {}", e.what());
×
76
                                ptr = nullptr;
×
77
                        }
78
                        if(ptr == nullptr) {
201✔
79
                                throw allocation_error(fmt::format("Allocation of {} bytes failed; likely out of memory. Currently allocated: {} out of {} bytes.",
3✔
80
                                    count * sizeof(T), m_global_mem_allocated_bytes, m_global_mem_total_size_bytes));
2✔
81
                        }
82
                        m_global_mem_allocated_bytes += size_bytes;
200✔
83
                        return device_allocation{ptr, size_bytes};
200✔
84
                }
85

86
                void free(device_allocation alloc) {
205✔
87
                        assert(m_sycl_queue != nullptr);
205✔
88
                        assert(alloc.size_bytes <= m_global_mem_allocated_bytes);
205✔
89
                        assert(alloc.ptr != nullptr || alloc.size_bytes == 0);
205✔
90
                        CELERITY_DEBUG("Freeing {} bytes on device", alloc.size_bytes);
410✔
91
                        if(alloc.size_bytes != 0) { sycl::free(alloc.ptr, *m_sycl_queue); }
205✔
92
                        m_global_mem_allocated_bytes -= alloc.size_bytes;
205✔
93
                }
205✔
94

95
                size_t get_global_memory_total_size_bytes() const { return m_global_mem_total_size_bytes; }
351✔
96

97
                size_t get_global_memory_allocated_bytes() const { return m_global_mem_allocated_bytes; }
340✔
98

99
                /**
100
                 * @brief Waits until all currently submitted operations have completed.
101
                 */
102
                void wait() { m_sycl_queue->wait_and_throw(); }
189✔
103

104
                /**
105
                 * @brief Returns whether device profiling is enabled.
106
                 */
107
                bool is_profiling_enabled() const { return m_device_profiling_enabled; }
470✔
108

109
                cl::sycl::queue& get_sycl_queue() const {
824✔
110
                        assert(m_sycl_queue != nullptr);
824✔
111
                        return *m_sycl_queue;
824✔
112
                }
113

114
          private:
115
                size_t m_global_mem_total_size_bytes = 0;
116
                size_t m_global_mem_allocated_bytes = 0;
117
                std::unique_ptr<cl::sycl::queue> m_sycl_queue;
118
                bool m_device_profiling_enabled = false;
119

120
                void handle_async_exceptions(cl::sycl::exception_list el) const;
121
        };
122

123
        // Try to find a platform that can provide a unique device for each node using a device selector.
124
        template <typename DeviceT, typename PlatformT, typename SelectorT>
UNCOV
125
        bool try_find_device_per_node(
×
126
            std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
UNCOV
127
                std::vector<std::tuple<DeviceT, size_t>> devices_with_platform_idx;
×
UNCOV
128
                for(size_t i = 0; i < platforms.size(); ++i) {
×
UNCOV
129
                        auto&& platform = platforms[i];
×
UNCOV
130
                        for(auto device : platform.get_devices()) {
×
UNCOV
131
                                if(selector(device) == -1) { continue; }
×
UNCOV
132
                                devices_with_platform_idx.emplace_back(device, i);
×
133
                        }
134
                }
135

UNCOV
136
                std::stable_sort(devices_with_platform_idx.begin(), devices_with_platform_idx.end(),
×
UNCOV
137
                    [selector](const auto& a, const auto& b) { return selector(std::get<0>(a)) > selector(std::get<0>(b)); });
×
UNCOV
138
                bool same_platform = true;
×
UNCOV
139
                bool same_device_type = true;
×
UNCOV
140
                if(devices_with_platform_idx.size() >= host_cfg.node_count) {
×
UNCOV
141
                        auto [device_from_platform, idx] = devices_with_platform_idx[0];
×
UNCOV
142
                        const auto platform = device_from_platform.get_platform();
×
UNCOV
143
                        const auto device_type = device_from_platform.template get_info<sycl::info::device::device_type>();
×
144

UNCOV
145
                        for(size_t i = 1; i < host_cfg.node_count; ++i) {
×
UNCOV
146
                                auto [device_from_platform, idx] = devices_with_platform_idx[i];
×
UNCOV
147
                                if(device_from_platform.get_platform() != platform) { same_platform = false; }
×
UNCOV
148
                                if(device_from_platform.template get_info<sycl::info::device::device_type>() != device_type) { same_device_type = false; }
×
149
                        }
150

UNCOV
151
                        if(!same_platform || !same_device_type) { CELERITY_WARN("Selected devices are of different type and/or do not belong to the same platform"); }
×
152

UNCOV
153
                        auto [selected_device_from_platform, selected_idx] = devices_with_platform_idx[host_cfg.local_rank];
×
UNCOV
154
                        how_selected = fmt::format("device selector specified: platform {}, device {}", selected_idx, host_cfg.local_rank);
×
UNCOV
155
                        device = selected_device_from_platform;
×
UNCOV
156
                        return true;
×
UNCOV
157
                }
×
158

UNCOV
159
                return false;
×
UNCOV
160
        }
×
161

162
        // Try to find a platform that can provide a unique device for each node.
163
        template <typename DeviceT, typename PlatformT>
164
        bool try_find_device_per_node(
268✔
165
            std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
166
                for(size_t i = 0; i < platforms.size(); ++i) {
536!
167
                        auto&& platform = platforms[i];
268✔
168
                        std::vector<DeviceT> platform_devices;
268✔
169

170
                        platform_devices = platform.get_devices(type);
268✔
171
                        if(platform_devices.size() >= host_cfg.node_count) {
268!
172
                                how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
536✔
173
                                device = platform_devices[host_cfg.local_rank];
268✔
174
                                return true;
268✔
175
                        }
176
                }
177

UNCOV
178
                return false;
×
179
        }
180

181
        template <typename DeviceT, typename PlatformT, typename SelectorT>
UNCOV
182
        bool try_find_one_device(
×
183
            std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
UNCOV
184
                std::vector<DeviceT> platform_devices;
×
UNCOV
185
                for(auto& p : platforms) {
×
UNCOV
186
                        auto p_devices = p.get_devices();
×
UNCOV
187
                        platform_devices.insert(platform_devices.end(), p_devices.begin(), p_devices.end());
×
188
                }
189

UNCOV
190
                std::stable_sort(platform_devices.begin(), platform_devices.end(), [selector](const auto& a, const auto& b) { return selector(a) > selector(b); });
×
UNCOV
191
                if(!platform_devices.empty()) {
×
UNCOV
192
                        if(selector(platform_devices[0]) == -1) { return false; }
×
UNCOV
193
                        device = platform_devices[0];
×
UNCOV
194
                        return true;
×
195
                }
196

197
                return false;
×
UNCOV
198
        };
×
199

200
        template <typename DeviceT, typename PlatformT>
UNCOV
201
        bool try_find_one_device(
×
202
            std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
UNCOV
203
                for(auto& p : platforms) {
×
UNCOV
204
                        for(auto& d : p.get_devices(type)) {
×
UNCOV
205
                                device = d;
×
UNCOV
206
                                return true;
×
207
                        }
208
                }
209

UNCOV
210
                return false;
×
211
        };
212

213

214
        template <typename DevicePtrOrSelector, typename PlatformT>
215
        auto pick_device(const config& cfg, const DevicePtrOrSelector& user_device_or_selector, const std::vector<PlatformT>& platforms) {
270✔
216
                using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;
217

218
                constexpr bool user_device_provided = std::is_same_v<DevicePtrOrSelector, DeviceT>;
270✔
219
                constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicePtrOrSelector, DeviceT>;
270✔
220
                constexpr bool auto_select = std::is_same_v<auto_select_device, DevicePtrOrSelector>;
270✔
221
                static_assert(
222
                    user_device_provided ^ device_selector_provided ^ auto_select, "pick_device requires either a device, a selector, or the auto_select_device tag");
223

224
                DeviceT device;
270✔
225
                std::string how_selected = "automatically selected";
540✔
226
                if constexpr(user_device_provided) {
227
                        device = user_device_or_selector;
2✔
228
                        how_selected = "specified by user";
2✔
229
                } else {
230
                        const auto device_cfg = cfg.get_device_config();
268✔
231
                        if(device_cfg != std::nullopt) {
268!
UNCOV
232
                                how_selected = fmt::format("set by CELERITY_DEVICES: platform {}, device {}", device_cfg->platform_id, device_cfg->device_id);
×
UNCOV
233
                                CELERITY_DEBUG("{} platforms available", platforms.size());
×
UNCOV
234
                                if(device_cfg->platform_id >= platforms.size()) {
×
UNCOV
235
                                        throw std::runtime_error(fmt::format("Invalid platform id {}: Only {} platforms available", device_cfg->platform_id, platforms.size()));
×
236
                                }
UNCOV
237
                                const auto devices = platforms[device_cfg->platform_id].get_devices();
×
UNCOV
238
                                if(device_cfg->device_id >= devices.size()) {
×
UNCOV
239
                                        throw std::runtime_error(fmt::format(
×
UNCOV
240
                                            "Invalid device id {}: Only {} devices available on platform {}", device_cfg->device_id, devices.size(), device_cfg->platform_id));
×
241
                                }
UNCOV
242
                                device = devices[device_cfg->device_id];
×
UNCOV
243
                        } else {
×
244
                                const auto host_cfg = cfg.get_host_config();
268✔
245

246
                                if constexpr(!device_selector_provided) {
247
                                        // Try to find a unique GPU per node.
248
                                        if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)) {
268!
UNCOV
249
                                                if(try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
×
UNCOV
250
                                                        CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
×
251
                                                } else {
UNCOV
252
                                                        CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
×
253
                                                        // Just use the first available device. Prefer GPUs, but settle for anything.
UNCOV
254
                                                        if(!try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)
×
UNCOV
255
                                                            && !try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
×
UNCOV
256
                                                                throw std::runtime_error("Automatic device selection failed: No device available");
×
257
                                                        }
258
                                                }
259
                                        }
260
                                } else {
261
                                        // Try to find a unique device per node using a selector.
UNCOV
262
                                        if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
×
UNCOV
263
                                                CELERITY_WARN("No suitable platform found that can provide {} devices that match the specified device selector, and "
×
264
                                                              "CELERITY_DEVICES not set",
265
                                                    host_cfg.node_count);
266
                                                // Use the first available device according to the selector, but fails if no such device is found.
UNCOV
267
                                                if(!try_find_one_device(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
×
UNCOV
268
                                                        throw std::runtime_error("Device selection with device selector failed: No device available");
×
269
                                                }
270
                                        }
271
                                }
272
                        }
273
                }
274

275
                const auto platform_name = device.get_platform().template get_info<sycl::info::platform::name>();
540✔
276
                const auto device_name = device.template get_info<sycl::info::device::name>();
270✔
277
                CELERITY_INFO("Using platform '{}', device '{}' ({})", platform_name, device_name, how_selected);
540✔
278

279
                if constexpr(std::is_same_v<DeviceT, sycl::device>) {
280
                        if(backend::get_effective_type(device) == backend::type::generic) {
270!
281
                                if(backend::get_type(device) == backend::type::unknown) {
270!
282
                                        CELERITY_WARN("No backend specialization available for selected platform '{}', falling back to generic. Performance may be degraded.",
540✔
283
                                            device.get_platform().template get_info<sycl::info::platform::name>());
284
                                } else {
285
                                        CELERITY_WARN(
×
286
                                            "Selected platform '{}' is compatible with specialized {} backend, but it has not been compiled. Performance may be degraded.",
287
                                            device.get_platform().template get_info<sycl::info::platform::name>(), backend::get_name(backend::get_type(device)));
288
                                }
289
                        } else {
290
                                CELERITY_DEBUG("Using {} backend for selected platform '{}'.", backend::get_name(backend::get_effective_type(device)),
×
291
                                    device.get_platform().template get_info<sycl::info::platform::name>());
292
                        }
293
                }
294

295
                return device;
540✔
296
        }
270✔
297

298
} // namespace detail
299
} // namespace celerity
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc