• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

celerity / celerity-runtime / 9582564578

19 Jun 2024 01:06PM UTC coverage: 93.536% (-1.2%) from 94.694%
9582564578

Pull #255

github

fknorr
Upstream multi-device selection & tests
Pull Request #255: [IDAG] Upstream Multi-Device Selection + Tests

3135 of 3561 branches covered (88.04%)

Branch coverage included in aggregate %.

107 of 109 new or added lines in 2 files covered. (98.17%)

66 existing lines in 2 files now uncovered.

7095 of 7376 relevant lines covered (96.19%)

189221.47 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.38
/include/device_selection.h
1
#pragma once
2

3
#include <functional>
4
#include <string>
5
#include <type_traits>
6
#include <variant>
7
#include <vector>
8

9
#include "config.h"
10
#include "log.h"
11

12
#include <sycl/sycl.hpp>
13

14
namespace celerity::detail {
15

16
// TODO these are required by distr_queue.h, but we don't want to pull all include dependencies of the pick_devices implementation into user code!
17
/// Empty tag type: passing it to pick_devices() as the devices-or-selector argument requests automatic device selection.
struct auto_select_devices {};
18
/// Callable that maps a SYCL device to an integer score.
/// NOTE(review): presumably consumed by pick_devices(), where a negative score excludes the device - confirm against callers.
using device_selector = std::function<int(const sycl::device&)>;
19
/// Holds one of three alternatives: the auto-selection tag, an explicit list of SYCL devices, or a scoring selector.
using devices_or_selector = std::variant<auto_select_devices, std::vector<sycl::device>, device_selector>;
20

21
template <typename DeviceT>
22
void check_required_device_aspects(const DeviceT& device) {
309✔
23
        if(!device.has(sycl::aspect::usm_device_allocations)) { throw std::runtime_error("device does not support USM device allocations"); }
309✔
24
        if(!device.has(sycl::aspect::usm_host_allocations)) { throw std::runtime_error("device does not support USM host allocations"); }
293✔
25
}
292✔
26

27
template <typename DevicesOrSelector, typename PlatformT>
28
auto pick_devices(const host_config& cfg, const DevicesOrSelector& user_devices_or_selector, const std::vector<PlatformT>& platforms) {
63✔
29
        using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;
30
        using BackendT = decltype(std::declval<DeviceT&>().get_backend());
31

32
        constexpr bool user_devices_provided = std::is_same_v<DevicesOrSelector, std::vector<DeviceT>>;
63✔
33
        constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicesOrSelector, DeviceT>;
63✔
34
        constexpr bool auto_select = std::is_same_v<auto_select_devices, DevicesOrSelector>;
63✔
35
        static_assert(user_devices_provided ^ device_selector_provided ^ auto_select,
36
            "pick_device requires either a list of devices, a selector, or the auto_select_devices tag");
37

38
        std::vector<DeviceT> selected_devices;
63✔
39
        std::string how_selected;
63✔
40

41
        if(cfg.node_count > 1) {
63✔
42
                CELERITY_WARN("Celerity detected more than one node (MPI rank) on this host, which is not recommended. Will attempt to distribute local devices evenly "
74✔
43
                              "across nodes.");
44
        }
45

46
        if constexpr(user_devices_provided) {
47
                const auto devices = user_devices_or_selector;
6✔
48
                if(devices.empty()) { throw std::runtime_error("Device selection failed: The user-provided list of devices is empty"); }
6✔
49
                auto backend = devices[0].get_backend();
5✔
50
                for(size_t i = 0; i < devices.size(); ++i) {
15✔
51
                        if(devices[i].get_backend() != backend) {
12✔
52
                                throw std::runtime_error("Device selection failed: The user-provided list of devices contains devices from different backends");
1✔
53
                        }
54
                        try {
55
                                check_required_device_aspects(devices[i]);
11✔
56
                        } catch(std::runtime_error& e) {
2!
57
                                throw std::runtime_error(fmt::format("Device selection failed: Device {} in user-provided list of devices caused error: {}", i, e.what()));
2✔
58
                        }
59
                }
60
                selected_devices = devices;
3✔
61
                how_selected = "specified by user";
3✔
62
        } else {
6✔
63
                if(std::all_of(platforms.cbegin(), platforms.cend(), [](auto& p) { return p.get_devices().empty(); })) {
114✔
64
                        throw std::runtime_error("Device selection failed: No devices available");
2✔
65
                }
66

67
                const auto select_all = [platforms](auto& selector) {
114✔
68
                        std::unordered_map<BackendT, std::vector<std::pair<DeviceT, size_t>>> scored_devices_by_backend;
59✔
69
                        for(size_t i = 0; i < platforms.size(); ++i) {
225!
70
                                const auto devices = platforms[i].get_devices(sycl::info::device_type::all);
83✔
71
                                for(size_t j = 0; j < devices.size(); ++j) {
378!
72
                                        try {
73
                                                check_required_device_aspects(devices[j]);
295✔
74
                                        } catch(std::runtime_error& e) {
28!
75
                                                CELERITY_TRACE("Ignoring device {} on platform {}: {}", j, i, e.what());
28✔
76
                                                continue;
14✔
77
                                        }
78
                                        const auto score = selector(devices[j]);
281✔
79
                                        if(score < 0) continue;
281!
80
                                        scored_devices_by_backend[devices[j].get_backend()].push_back(std::pair{devices[j], score});
261✔
81
                                }
82
                        }
83
                        size_t max_score = 0;
59✔
84
                        std::vector<DeviceT> max_score_devices;
59✔
85
                        for(auto& [backend, scored_devices] : scored_devices_by_backend) {
199!
86
                                size_t sum_score = 0;
70✔
87
                                std::vector<DeviceT> devices;
70✔
88
                                for(auto& [d, score] : scored_devices) {
331!
89
                                        sum_score += score;
261✔
90
                                        devices.push_back(d);
261✔
91
                                }
92
                                if(sum_score > max_score) {
70!
93
                                        max_score = sum_score;
57✔
94
                                        max_score_devices = std::move(devices);
57✔
95
                                }
96
                        }
97
                        return max_score_devices;
59✔
98
                };
59✔
99

100
                if constexpr(device_selector_provided) {
101
                        how_selected = "via user-provided selector";
9✔
102
                        selected_devices = select_all(user_devices_or_selector);
9✔
103
                } else {
104
                        how_selected = "automatically selected";
46✔
105
                        // First try to find eligible GPUs
106
                        const auto selector = [](const DeviceT& d) {
285✔
107
                                return d.template get_info<sycl::info::device::device_type>() == sycl::info::device_type::gpu ? 1 : -1;
239✔
108
                        };
109
                        selected_devices = select_all(selector);
46✔
110
                        if(selected_devices.empty()) {
46✔
111
                                // If none were found, fall back to other device types
112
                                const auto selector = [](const DeviceT& d) { return 1; };
14✔
113
                                selected_devices = select_all(selector);
4✔
114
                        }
115
                }
116

117
                if(selected_devices.empty()) { throw std::runtime_error("Device selection failed: No eligible devices found"); }
55✔
118
        }
55✔
119

120
        // When running with more than one local node, attempt to distribute devices evenly
121
        if(cfg.node_count > 1) {
55✔
122
                if(selected_devices.size() >= cfg.node_count) {
37✔
123
                        const size_t quotient = selected_devices.size() / cfg.node_count;
22✔
124
                        const size_t remainder = selected_devices.size() % cfg.node_count;
22✔
125

126
                        const auto rank = cfg.local_rank;
22✔
127
                        const size_t offset = rank < remainder ? rank * (quotient + 1) : remainder * (quotient + 1) + (rank - remainder) * quotient;
22✔
128
                        const size_t count = rank < remainder ? quotient + 1 : quotient;
22✔
129

130
                        std::vector<DeviceT> subset{selected_devices.begin() + offset, selected_devices.begin() + offset + count};
44✔
131
                        selected_devices = std::move(subset);
22✔
132
                } else {
22✔
133
                        CELERITY_WARN(
30✔
134
                            "Found fewer devices ({}) than local nodes ({}), multiple nodes will use the same device(s).", selected_devices.size(), cfg.node_count);
135
                        selected_devices = {selected_devices[cfg.local_rank % selected_devices.size()]};
30✔
136
                }
137
        }
138

139
        for(auto& device : selected_devices) {
243✔
140
                const auto platform_name = device.get_platform().template get_info<sycl::info::platform::name>();
94✔
141
                const auto device_name = device.template get_info<sycl::info::device::name>();
94✔
142
                CELERITY_INFO("Using platform '{}', device '{}' ({})", platform_name, device_name, how_selected);
188✔
143
        }
144

145
        return selected_devices;
55✔
146
}
71✔
147

148
/*
149
template<typename T>
150
concept BackendEnumerator = requires(const T &a) {
151
        typename T::backend_type;
152
        typename T::device_type;
153
        {a.compatible_backends(std::declval<typename T::device_type>)} -> std::same_as<std::vector<T::backend_type>>;
154
        {a.available_backends()} -> std::same_as<std::vector<T::backend_type>>;
155
        {a.is_specialized(std::declval<T::backend_type>())} -> std::same_as<bool>;
156
        {a.get_priority(std::declval<T::backend_type>())} -> std::same_as<int>;
157
};
158
*/
159

160
template <typename BackendEnumerator>
161
inline auto select_backend(const BackendEnumerator& enumerator, const std::vector<typename BackendEnumerator::device_type>& devices) {
3✔
162
        using backend_type = typename BackendEnumerator::backend_type;
163

164
        const auto available_backends = enumerator.available_backends();
3✔
165

166
        std::vector<backend_type> common_backends;
3✔
167
        for(auto& device : devices) {
15✔
168
                auto device_backends = enumerator.compatible_backends(device);
6✔
169
                common_backends = common_backends.empty() ? std::move(device_backends) : utils::set_intersection(common_backends, device_backends);
6✔
170
        }
171

172
        assert(!common_backends.empty());
3✔
173
        std::sort(common_backends.begin(), common_backends.end(),
3✔
174
            [&](const backend_type lhs, const backend_type rhs) { return enumerator.get_priority(lhs) > enumerator.get_priority(rhs); });
10✔
175

176
        for(const auto backend : common_backends) {
6!
177
                const auto is_specialized = enumerator.is_specialized(backend);
6✔
178
                if(utils::contains(available_backends, backend)) {
6✔
179
                        if(is_specialized) {
3✔
180
                                CELERITY_DEBUG("Using {} backend for the selected devices.", backend);
2✔
181
                        } else {
182
                                CELERITY_WARN("No common backend specialization available for all selected devices, falling back to {}. Performance may be degraded.", backend);
4✔
183
                        }
184
                        return backend;
6✔
185
                } else if(is_specialized) {
3!
186
                        CELERITY_WARN(
6✔
187
                            "All selected devices are compatible with specialized {} backend, but it has not been compiled. Performance may be degraded.", backend);
188
                }
189
        }
NEW
190
        utils::panic("no compatible backend available");
×
191
}
3✔
192

193
} // namespace celerity::detail
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc