• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

celerity / celerity-runtime / 9908434959

12 Jul 2024 01:10PM UTC coverage: 93.535% (-1.1%) from 94.676%
9908434959

Pull #255

github

fknorr
[RM] improve logging specificity on idag-device-selection
Pull Request #255: [IDAG] Upstream Multi-Device Selection + Tests

3134 of 3561 branches covered (88.01%)

Branch coverage included in aggregate %.

107 of 109 new or added lines in 2 files covered. (98.17%)

65 existing lines in 1 file now uncovered.

7095 of 7375 relevant lines covered (96.2%)

282239.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.38
/include/device_selection.h
1
#pragma once
2

3
#include <functional>
4
#include <string>
5
#include <type_traits>
6
#include <variant>
7
#include <vector>
8

9
#include "config.h"
10
#include "log.h"
11

12
#include <sycl/sycl.hpp>
13

14
namespace celerity::detail {
15

16
// TODO these are required by distr_queue.h, but we don't want to pull all include dependencies of the pick_devices implementation into user code!
17
// Tag type: passing this as `DevicesOrSelector` to pick_devices() requests automatic device selection
// (eligible GPUs first; if none are found, any eligible device).
struct auto_select_devices {};
18
// User-supplied device scoring function. pick_devices() discards devices that receive a negative score and
// picks the backend whose remaining devices have the highest total score.
using device_selector = std::function<int(const sycl::device&)>;
19
// The three ways device selection can be requested: automatically (tag), by an explicit device list, or via a selector.
using devices_or_selector = std::variant<auto_select_devices, std::vector<sycl::device>, device_selector>;
20

21
template <typename DeviceT>
22
void check_required_device_aspects(const DeviceT& device) {
309✔
23
        if(!device.has(sycl::aspect::usm_device_allocations)) { throw std::runtime_error("device does not support USM device allocations"); }
309✔
24
        if(!device.has(sycl::aspect::usm_host_allocations)) { throw std::runtime_error("device does not support USM host allocations"); }
293✔
25
}
292✔
26

27
template <typename DevicesOrSelector, typename PlatformT>
28
auto pick_devices(const host_config& cfg, const DevicesOrSelector& user_devices_or_selector, const std::vector<PlatformT>& platforms) {
63✔
29
        using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;
30
        using BackendT = decltype(std::declval<DeviceT&>().get_backend());
31

32
        constexpr bool user_devices_provided = std::is_same_v<DevicesOrSelector, std::vector<DeviceT>>;
63✔
33
        constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicesOrSelector, DeviceT>;
63✔
34
        constexpr bool auto_select = std::is_same_v<auto_select_devices, DevicesOrSelector>;
63✔
35
        static_assert(user_devices_provided ^ device_selector_provided ^ auto_select,
36
            "pick_device requires either a list of devices, a selector, or the auto_select_devices tag");
37

38
        std::vector<DeviceT> selected_devices;
63✔
39
        std::string how_selected;
63✔
40

41
        if(cfg.node_count > 1) {
63✔
42
                CELERITY_WARN("Celerity detected more than one node (MPI rank) on this host, which is not recommended. Will attempt to distribute local devices evenly "
74✔
43
                              "across nodes.");
44
        }
45

46
        if constexpr(user_devices_provided) {
47
                const auto devices = user_devices_or_selector;
6✔
48
                if(devices.empty()) { throw std::runtime_error("Device selection failed: The user-provided list of devices is empty"); }
6✔
49
                auto backend = devices[0].get_backend();
5✔
50
                for(size_t i = 0; i < devices.size(); ++i) {
15✔
51
                        if(devices[i].get_backend() != backend) {
12✔
52
                                throw std::runtime_error("Device selection failed: The user-provided list of devices contains devices from different backends");
1✔
53
                        }
54
                        try {
55
                                check_required_device_aspects(devices[i]);
11✔
56
                        } catch(std::runtime_error& e) {
2!
57
                                throw std::runtime_error(fmt::format("Device selection failed: Device {} in user-provided list of devices caused error: {}", i, e.what()));
2✔
58
                        }
59
                }
60
                selected_devices = devices;
3✔
61
                how_selected = "specified by user";
3✔
62
        } else {
6✔
63
                if(std::all_of(platforms.cbegin(), platforms.cend(), [](auto& p) { return p.get_devices().empty(); })) {
114✔
64
                        throw std::runtime_error("Device selection failed: No devices available");
2✔
65
                }
66

67
                const auto select_all = [platforms](auto& selector) {
114✔
68
                        std::unordered_map<BackendT, std::vector<std::pair<DeviceT, size_t>>> scored_devices_by_backend;
59✔
69
                        for(size_t i = 0; i < platforms.size(); ++i) {
225!
70
                                const auto devices = platforms[i].get_devices(sycl::info::device_type::all);
83✔
71
                                for(size_t j = 0; j < devices.size(); ++j) {
378!
72
                                        try {
73
                                                check_required_device_aspects(devices[j]);
295✔
74
                                        } catch(std::runtime_error& e) {
28!
75
                                                CELERITY_TRACE("Ignoring platform {} \"{}\", device {} \"{}\": {}", i, platforms[i].template get_info<sycl::info::platform::name>(), j,
28✔
76
                                                    devices[j].template get_info<sycl::info::device::name>(), e.what());
77
                                                continue;
14✔
78
                                        }
79
                                        const auto score = selector(devices[j]);
281✔
80
                                        if(score < 0) continue;
281!
81
                                        scored_devices_by_backend[devices[j].get_backend()].push_back(std::pair{devices[j], score});
261✔
82
                                }
83
                        }
84
                        size_t max_score = 0;
59✔
85
                        std::vector<DeviceT> max_score_devices;
59✔
86
                        for(auto& [backend, scored_devices] : scored_devices_by_backend) {
199!
87
                                size_t sum_score = 0;
70✔
88
                                std::vector<DeviceT> devices;
70✔
89
                                for(auto& [d, score] : scored_devices) {
331!
90
                                        sum_score += score;
261✔
91
                                        devices.push_back(d);
261✔
92
                                }
93
                                if(sum_score > max_score) {
70!
94
                                        max_score = sum_score;
57✔
95
                                        max_score_devices = std::move(devices);
57✔
96
                                }
97
                        }
98
                        return max_score_devices;
59✔
99
                };
59✔
100

101
                if constexpr(device_selector_provided) {
102
                        how_selected = "via user-provided selector";
9✔
103
                        selected_devices = select_all(user_devices_or_selector);
9✔
104
                } else {
105
                        how_selected = "automatically selected";
46✔
106
                        // First try to find eligible GPUs
107
                        const auto selector = [](const DeviceT& d) {
285✔
108
                                return d.template get_info<sycl::info::device::device_type>() == sycl::info::device_type::gpu ? 1 : -1;
239✔
109
                        };
110
                        selected_devices = select_all(selector);
46✔
111
                        if(selected_devices.empty()) {
46✔
112
                                // If none were found, fall back to other device types
113
                                const auto selector = [](const DeviceT& d) { return 1; };
14✔
114
                                selected_devices = select_all(selector);
4✔
115
                        }
116
                }
117

118
                if(selected_devices.empty()) { throw std::runtime_error("Device selection failed: No eligible devices found"); }
55✔
119
        }
55✔
120

121
        // When running with more than one local node, attempt to distribute devices evenly
122
        if(cfg.node_count > 1) {
55✔
123
                if(selected_devices.size() >= cfg.node_count) {
37✔
124
                        const size_t quotient = selected_devices.size() / cfg.node_count;
22✔
125
                        const size_t remainder = selected_devices.size() % cfg.node_count;
22✔
126

127
                        const auto rank = cfg.local_rank;
22✔
128
                        const size_t offset = rank < remainder ? rank * (quotient + 1) : remainder * (quotient + 1) + (rank - remainder) * quotient;
22✔
129
                        const size_t count = rank < remainder ? quotient + 1 : quotient;
22✔
130

131
                        std::vector<DeviceT> subset{selected_devices.begin() + offset, selected_devices.begin() + offset + count};
44✔
132
                        selected_devices = std::move(subset);
22✔
133
                } else {
22✔
134
                        CELERITY_WARN(
30✔
135
                            "Found fewer devices ({}) than local nodes ({}), multiple nodes will use the same device(s).", selected_devices.size(), cfg.node_count);
136
                        selected_devices = {selected_devices[cfg.local_rank % selected_devices.size()]};
30✔
137
                }
138
        }
139

140
        for(device_id did = 0; did < selected_devices.size(); ++did) {
243✔
141
                const auto platform_name = selected_devices[did].get_platform().template get_info<sycl::info::platform::name>();
94✔
142
                const auto device_name = selected_devices[did].template get_info<sycl::info::device::name>();
94✔
143
                CELERITY_INFO("Using platform \"{}\", device \"{}\" as D{} ({})", platform_name, device_name, did, how_selected);
188✔
144
        }
145

146
        return selected_devices;
55✔
147
}
71✔
148

149
/*
150
template<typename T>
151
concept BackendEnumerator = requires(const T &a) {
152
    typename T::backend_type;
153
    typename T::device_type;
154
    {a.compatible_backends(std::declval<typename T::device_type>())} -> std::same_as<std::vector<typename T::backend_type>>;
155
    {a.available_backends()} -> std::same_as<std::vector<typename T::backend_type>>;
156
    {a.is_specialized(std::declval<typename T::backend_type>())} -> std::same_as<bool>;
157
    {a.get_priority(std::declval<typename T::backend_type>())} -> std::same_as<int>;
158
};
159
*/
160

161
template <typename BackendEnumerator>
162
inline auto select_backend(const BackendEnumerator& enumerator, const std::vector<typename BackendEnumerator::device_type>& devices) {
3✔
163
        using backend_type = typename BackendEnumerator::backend_type;
164

165
        const auto available_backends = enumerator.available_backends();
3✔
166

167
        std::vector<backend_type> common_backends;
3✔
168
        for(auto& device : devices) {
15✔
169
                auto device_backends = enumerator.compatible_backends(device);
6✔
170
                common_backends = common_backends.empty() ? std::move(device_backends) : utils::set_intersection(common_backends, device_backends);
6✔
171
        }
172

173
        assert(!common_backends.empty());
3✔
174
        std::sort(common_backends.begin(), common_backends.end(),
3✔
175
            [&](const backend_type lhs, const backend_type rhs) { return enumerator.get_priority(lhs) > enumerator.get_priority(rhs); });
10✔
176

177
        for(const auto backend : common_backends) {
6!
178
                const auto is_specialized = enumerator.is_specialized(backend);
6✔
179
                if(utils::contains(available_backends, backend)) {
6✔
180
                        if(is_specialized) {
3✔
181
                                CELERITY_DEBUG("Using {} backend for the selected devices.", backend);
2✔
182
                        } else {
183
                                CELERITY_WARN("No common backend specialization available for all selected devices, falling back to {}. Performance may be degraded.", backend);
4✔
184
                        }
185
                        return backend;
6✔
186
                } else if(is_specialized) {
3!
187
                        CELERITY_WARN(
6✔
188
                            "All selected devices are compatible with specialized {} backend, but it has not been compiled. Performance may be degraded.", backend);
189
                }
190
        }
NEW
191
        utils::panic("no compatible backend available");
×
192
}
3✔
193

194
} // namespace celerity::detail
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc