• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

celerity / celerity-runtime / 12009901531

25 Nov 2024 12:20PM UTC coverage: 94.92% (+0.009%) from 94.911%
12009901531

push

github

fknorr
Add missing includes and consistently order them

We can't add the misc-include-cleaner lint because it causes too many
false positives with "interface headers" such as sycl.hpp.

3190 of 3626 branches covered (87.98%)

Branch coverage included in aggregate %.

7049 of 7161 relevant lines covered (98.44%)

1242183.17 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.1
/include/device_selection.h
1
#pragma once
2

3
#include "config.h"
4
#include "log.h"
5
#include "types.h"
6
#include "utils.h"
7

8
#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
19

20
#include <fmt/format.h>
21
#include <sycl/sycl.hpp>
22

23

24
namespace celerity::detail {
25

26
// TODO these are required by distr_queue.h, but we don't want to pull all include dependencies of the pick_devices implementation into user code!
27
struct auto_select_devices {};
28
using device_selector = std::function<int(const sycl::device&)>;
29
using devices_or_selector = std::variant<auto_select_devices, std::vector<sycl::device>, device_selector>;
30

31
template <typename DeviceT>
32
void check_required_device_aspects(const DeviceT& device) {
1,193✔
33
        if(!device.has(sycl::aspect::usm_device_allocations)) { throw std::runtime_error("device does not support USM device allocations"); }
1,193✔
34
        if(!device.has(sycl::aspect::usm_host_allocations)) { throw std::runtime_error("device does not support USM host allocations"); }
1,177✔
35
}
1,176✔
36

37
template <typename DevicesOrSelector, typename PlatformT>
38
auto pick_devices(const host_config& cfg, const DevicesOrSelector& user_devices_or_selector, const std::vector<PlatformT>& platforms) {
293✔
39
        using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;
40
        using BackendT = decltype(std::declval<DeviceT&>().get_backend());
41

42
        constexpr bool user_devices_provided = std::is_same_v<DevicesOrSelector, std::vector<DeviceT>>;
293✔
43
        constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicesOrSelector, DeviceT>;
293✔
44
        constexpr bool auto_select = std::is_same_v<auto_select_devices, DevicesOrSelector>;
293✔
45
        static_assert(user_devices_provided ^ device_selector_provided ^ auto_select,
46
            "pick_device requires either a list of devices, a selector, or the auto_select_devices tag");
47

48
        std::vector<DeviceT> selected_devices;
293✔
49
        std::string how_selected;
293✔
50

51
        if(cfg.node_count > 1) {
293✔
52
                CELERITY_WARN("Celerity detected more than one node (MPI rank) on this host, which is not recommended. Will attempt to distribute local devices evenly "
350✔
53
                              "across nodes.");
54
        }
55

56
        if constexpr(user_devices_provided) {
57
                const auto devices = user_devices_or_selector;
18✔
58
                if(devices.empty()) { throw std::runtime_error("Device selection failed: The user-provided list of devices is empty"); }
18✔
59
                auto backend = devices[0].get_backend();
17✔
60
                for(size_t i = 0; i < devices.size(); ++i) {
39✔
61
                        if(devices[i].get_backend() != backend) {
24✔
62
                                throw std::runtime_error("Device selection failed: The user-provided list of devices contains devices from different backends");
1✔
63
                        }
64
                        try {
65
                                check_required_device_aspects(devices[i]);
23✔
66
                        } catch(std::runtime_error& e) {
2!
67
                                throw std::runtime_error(fmt::format("Device selection failed: Device {} in user-provided list of devices caused error: {}", i, e.what()));
2✔
68
                        }
69
                }
70
                selected_devices = devices;
15✔
71
                how_selected = "specified by user";
15✔
72
        } else {
18✔
73
                if(std::all_of(platforms.cbegin(), platforms.cend(), [](auto& p) { return p.get_devices().empty(); })) {
550✔
74
                        throw std::runtime_error("Device selection failed: No devices available");
2✔
75
                }
76

77
                const auto select_all = [platforms](auto& selector) {
550✔
78
                        std::unordered_map<BackendT, std::vector<std::pair<DeviceT, size_t>>> scored_devices_by_backend;
277✔
79
                        for(size_t i = 0; i < platforms.size(); ++i) {
879✔
80
                                const auto devices = platforms[i].get_devices(sycl::info::device_type::all);
301✔
81
                                for(size_t j = 0; j < devices.size(); ++j) {
1,468✔
82
                                        try {
83
                                                check_required_device_aspects(devices[j]);
1,167✔
84
                                        } catch(std::runtime_error& e) {
28!
85
                                                CELERITY_TRACE("Ignoring platform {} \"{}\", device {} \"{}\": {}", i, platforms[i].template get_info<sycl::info::platform::name>(), j,
14✔
86
                                                    devices[j].template get_info<sycl::info::device::name>(), e.what());
87
                                                continue;
14✔
88
                                        }
89
                                        const auto score = selector(devices[j]);
1,153✔
90
                                        if(score < 0) continue;
1,153!
91
                                        scored_devices_by_backend[devices[j].get_backend()].push_back(std::pair{devices[j], score});
1,133✔
92
                                }
93
                        }
94
                        size_t max_score = 0;
277✔
95
                        std::vector<DeviceT> max_score_devices;
277✔
96
                        for(auto& [backend, scored_devices] : scored_devices_by_backend) {
853!
97
                                size_t sum_score = 0;
288✔
98
                                std::vector<DeviceT> devices;
288✔
99
                                for(auto& [d, score] : scored_devices) {
1,421!
100
                                        sum_score += score;
1,133✔
101
                                        devices.push_back(d);
1,133✔
102
                                }
103
                                if(sum_score > max_score) {
288!
104
                                        max_score = sum_score;
275✔
105
                                        max_score_devices = std::move(devices);
275✔
106
                                }
107
                        }
108
                        return max_score_devices;
277✔
109
                };
277✔
110

111
                if constexpr(device_selector_provided) {
112
                        how_selected = "via user-provided selector";
9✔
113
                        selected_devices = select_all(user_devices_or_selector);
9✔
114
                } else {
115
                        how_selected = "automatically selected";
264✔
116
                        // First try to find eligible GPUs
117
                        const auto selector = [](const DeviceT& d) {
1,375✔
118
                                return d.template get_info<sycl::info::device::device_type>() == sycl::info::device_type::gpu ? 1 : -1;
1,111✔
119
                        };
120
                        selected_devices = select_all(selector);
264✔
121
                        if(selected_devices.empty()) {
264✔
122
                                // If none were found, fall back to other device types
123
                                const auto selector = [](const DeviceT& d) { return 1; };
14✔
124
                                selected_devices = select_all(selector);
4✔
125
                        }
126
                }
127

128
                if(selected_devices.empty()) { throw std::runtime_error("Device selection failed: No eligible devices found"); }
273✔
129
        }
273✔
130

131
        // When running with more than one local node, attempt to distribute devices evenly
132
        if(cfg.node_count > 1) {
285✔
133
                if(selected_devices.size() >= cfg.node_count) {
175✔
134
                        const size_t quotient = selected_devices.size() / cfg.node_count;
160✔
135
                        const size_t remainder = selected_devices.size() % cfg.node_count;
160✔
136

137
                        const auto rank = cfg.local_rank;
160✔
138
                        const size_t offset = rank < remainder ? rank * (quotient + 1) : remainder * (quotient + 1) + (rank - remainder) * quotient;
160✔
139
                        const size_t count = rank < remainder ? quotient + 1 : quotient;
160✔
140

141
                        std::vector<DeviceT> subset{selected_devices.begin() + offset, selected_devices.begin() + offset + count};
480✔
142
                        selected_devices = std::move(subset);
160✔
143
                } else {
160✔
144
                        CELERITY_WARN(
15✔
145
                            "Found fewer devices ({}) than local nodes ({}), multiple nodes will use the same device(s).", selected_devices.size(), cfg.node_count);
146
                        selected_devices = {selected_devices[cfg.local_rank % selected_devices.size()]};
60!
147
                }
148
        }
149

150
        for(device_id did = 0; did < selected_devices.size(); ++did) {
1,505✔
151
                const auto platform_name = selected_devices[did].get_platform().template get_info<sycl::info::platform::name>();
610✔
152
                const auto device_name = selected_devices[did].template get_info<sycl::info::device::name>();
610✔
153
                CELERITY_INFO("Using platform \"{}\", device \"{}\" as D{} ({})", platform_name, device_name, did, how_selected);
610!
154
        }
155

156
        return selected_devices;
285✔
157
}
331✔
158

159
template <typename T>
160
concept BackendEnumerator = requires(const T& a) {
161
        typename T::backend_type;
162
        typename T::device_type;
163
        { a.compatible_backends(std::declval<typename T::device_type>()) } -> std::same_as<std::vector<typename T::backend_type>>;
164
        { a.available_backends() } -> std::same_as<std::vector<typename T::backend_type>>;
165
        { a.is_specialized(std::declval<typename T::backend_type>()) } -> std::same_as<bool>;
166
        { a.get_priority(std::declval<typename T::backend_type>()) } -> std::same_as<int>;
167
};
168

169
template <BackendEnumerator E>
170
inline auto select_backend(const E& enumerator, const std::vector<typename E::device_type>& devices) {
233✔
171
        using backend_type = typename E::backend_type;
172

173
        const auto available_backends = enumerator.available_backends();
233✔
174

175
        std::vector<backend_type> common_backends;
233✔
176
        for(auto& device : devices) {
1,277✔
177
                auto device_backends = enumerator.compatible_backends(device);
522✔
178
                common_backends = common_backends.empty() ? std::move(device_backends) : utils::set_intersection(common_backends, device_backends);
522✔
179
        }
180

181
        assert(!common_backends.empty());
233✔
182
        std::sort(common_backends.begin(), common_backends.end(),
233✔
183
            [&](const backend_type lhs, const backend_type rhs) { return enumerator.get_priority(lhs) > enumerator.get_priority(rhs); });
10✔
184

185
        for(const auto backend : common_backends) {
236!
186
                const auto is_specialized = enumerator.is_specialized(backend);
236✔
187
                if(utils::contains(available_backends, backend)) {
236✔
188
                        if(is_specialized) {
233✔
189
                                CELERITY_DEBUG("Using {} backend for the selected devices.", backend);
1!
190
                        } else {
191
                                CELERITY_WARN("No common backend specialization available for all selected devices, falling back to {}. Performance may be degraded.", backend);
232!
192
                        }
193
                        return backend;
466✔
194
                } else if(is_specialized) {
3!
195
                        CELERITY_WARN(
3!
196
                            "All selected devices are compatible with specialized {} backend, but it has not been compiled. Performance may be degraded.", backend);
197
                }
198
        }
199
        utils::panic("no compatible backend available");
×
200
}
233✔
201

202
} // namespace celerity::detail
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc