• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

celerity / celerity-runtime / 8329571092

18 Mar 2024 03:49PM UTC coverage: 94.63% (+0.7%) from 93.968%
8329571092

push

github

fknorr
Update benchmark results for IDAG generation

2907 of 3248 branches covered (89.5%)

Branch coverage included in aggregate %.

6574 of 6771 relevant lines covered (97.09%)

179871.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.26
/include/task.h
1
#pragma once
2

3
#include <memory>
4
#include <unordered_map>
5
#include <unordered_set>
6
#include <utility>
7
#include <vector>
8

9
#include "device_queue.h"
10
#include "grid.h"
11
#include "hint.h"
12
#include "host_queue.h"
13
#include "intrusive_graph.h"
14
#include "launcher.h"
15
#include "lifetime_extending_state.h"
16
#include "range_mapper.h"
17
#include "types.h"
18

19
namespace celerity {
20

21
class handler;
22

23
namespace detail {
24

25
        /// Kind of a node in the task graph.
        enum class task_type {
                /// Task epoch (graph-level serialization point).
                epoch,
                /// Host task with explicit global size and celerity-defined split.
                host_compute,
                /// Device compute task.
                device_compute,
                /// Host task with implicit 1d global size = #ranks and fixed split.
                collective,
                /// Zero-dimensional host task.
                master_node,
                /// Task horizon.
                horizon,
                /// Promise-side of an async experimental::fence.
                fence,
        };

        /// Where the work of a task is executed. `none` is used for task types that carry no user work
        /// (task::get_execution_target maps epoch, horizon and fence tasks to it).
        enum class execution_target { none, host, device };
40

41
        class command_launcher_storage_base {
42
          public:
43
                command_launcher_storage_base() = default;
4,410✔
44
                command_launcher_storage_base(const command_launcher_storage_base&) = delete;
45
                command_launcher_storage_base(command_launcher_storage_base&&) = default;
46
                command_launcher_storage_base& operator=(const command_launcher_storage_base&) = delete;
47
                command_launcher_storage_base& operator=(command_launcher_storage_base&&) = default;
48
                virtual ~command_launcher_storage_base() = default;
4,410✔
49

50
                virtual sycl::event operator()(
51
                    device_queue& q, const subrange<3> execution_sr, const std::vector<void*>& reduction_ptrs, const bool is_reduction_initializer) const = 0;
52
                virtual std::future<host_queue::execution_info> operator()(host_queue& q, const subrange<3>& execution_sr) const = 0;
53
        };
54

55
        template <typename Functor>
56
        class command_launcher_storage : public command_launcher_storage_base {
57
          public:
58
                command_launcher_storage(Functor&& fun) : m_fun(std::move(fun)) {}
4,410✔
59

60
                sycl::event operator()(
169✔
61
                    device_queue& q, const subrange<3> execution_sr, const std::vector<void*>& reduction_ptrs, const bool is_reduction_initializer) const override {
62
                        return invoke<sycl::event>(q, execution_sr, reduction_ptrs, is_reduction_initializer);
169✔
63
                }
64

65
                std::future<host_queue::execution_info> operator()(host_queue& q, const subrange<3>& execution_sr) const override {
3,158✔
66
                        return invoke<std::future<host_queue::execution_info>>(q, execution_sr);
3,158✔
67
                }
68

69
          private:
70
                Functor m_fun;
71

72
                template <typename Ret, typename... Args>
73
                Ret invoke(Args&&... args) const {
3,327✔
74
                        if constexpr(std::is_invocable_v<Functor, Args...>) {
75
                                return m_fun(args...);
3,327✔
76
                        } else {
77
                                throw std::runtime_error("Cannot launch command function with provided arguments");
×
78
                        }
79
                }
80
        };
81

82
        class buffer_access_map {
83
          public:
84
                void add_access(buffer_id bid, std::unique_ptr<range_mapper_base>&& rm) { m_accesses.emplace_back(bid, std::move(rm)); }
3,573✔
85

86
                std::unordered_set<buffer_id> get_accessed_buffers() const;
87
                std::unordered_set<cl::sycl::access::mode> get_access_modes(buffer_id bid) const;
88
                size_t get_num_accesses() const { return m_accesses.size(); }
22,479✔
89
                std::pair<buffer_id, access_mode> get_nth_access(const size_t n) const {
2,668✔
90
                        const auto& [bid, rm] = m_accesses[n];
2,668✔
91
                        return {bid, rm->get_access_mode()};
2,668✔
92
                }
93

94
                /**
95
                 * @brief Computes the combined access-region for a given buffer, mode and subrange.
96
                 *
97
                 * @param bid
98
                 * @param mode
99
                 * @param sr The subrange to be passed to the range mappers (extended to a chunk using the global size of the task)
100
                 *
101
                 * @returns The region obtained by merging the results of all range-mappers for this buffer and mode
102
                 */
103
                region<3> get_mode_requirements(
104
                    const buffer_id bid, const access_mode mode, const int kernel_dims, const subrange<3>& sr, const range<3>& global_size) const;
105

106
                box<3> get_requirements_for_nth_access(const size_t n, const int kernel_dims, const subrange<3>& sr, const range<3>& global_size) const;
107

108
                std::vector<const range_mapper_base*> get_range_mappers(const buffer_id bid) const {
109
                        std::vector<const range_mapper_base*> rms;
110
                        for(const auto& [a_bid, a_rm] : m_accesses) {
111
                                if(a_bid == bid) { rms.push_back(a_rm.get()); }
112
                        }
113
                        return rms;
114
                }
115

116
                box_vector<3> get_required_contiguous_boxes(const buffer_id bid, const int kernel_dims, const subrange<3>& sr, const range<3>& global_size) const;
117

118
          private:
119
                std::vector<std::pair<buffer_id, std::unique_ptr<range_mapper_base>>> m_accesses;
120
        };
121

122
        using reduction_set = std::vector<reduction_info>;
123

124
        class side_effect_map : private std::unordered_map<host_object_id, experimental::side_effect_order> {
125
          private:
126
                using map_base = std::unordered_map<host_object_id, experimental::side_effect_order>;
127

128
          public:
129
                using typename map_base::const_iterator, map_base::value_type, map_base::key_type, map_base::mapped_type, map_base::const_reference,
130
                    map_base::const_pointer;
131
                using iterator = const_iterator;
132
                using reference = const_reference;
133
                using pointer = const_pointer;
134

135
                using map_base::size, map_base::count, map_base::empty, map_base::cbegin, map_base::cend, map_base::at;
136

137
                iterator begin() const { return cbegin(); }
6,793✔
138
                iterator end() const { return cend(); }
6,791✔
139
                iterator find(host_object_id key) const { return map_base::find(key); }
140

141
                void add_side_effect(host_object_id hoid, experimental::side_effect_order order);
142
        };
143

144
        class fence_promise {
145
          public:
146
                fence_promise() = default;
42✔
147
                fence_promise(const fence_promise&) = delete;
148
                fence_promise& operator=(const fence_promise&) = delete;
149
                virtual ~fence_promise() = default;
42✔
150

151
                virtual void fulfill() = 0;
152
                virtual allocation_id get_user_allocation_id() = 0;
153
        };
154

155
        struct task_geometry {
156
                int dimensions = 0;
157
                range<3> global_size{1, 1, 1};
158
                id<3> global_offset{};
159
                range<3> granularity{1, 1, 1};
160
        };
161

162
        /// A node of the task graph.
        ///
        /// Tasks are created only through the static make_* factories. Each factory fixes the task_type and supplies
        /// only the state that type uses; for example, epoch, horizon and fence tasks have no launcher.
        class task : public intrusive_graph_node<task> {
          public:
                task_type get_type() const { return m_type; }

                task_id get_id() const { return m_tid; }

                /// non_collective_group_id for every task type except collective tasks.
                collective_group_id get_collective_group_id() const { return m_cgid; }

                const buffer_access_map& get_buffer_access_map() const { return m_access_map; }

                const side_effect_map& get_side_effect_map() const { return m_side_effects; }

                const task_geometry& get_geometry() const { return m_geometry; }

                int get_dimensions() const { return m_geometry.dimensions; }

                range<3> get_global_size() const { return m_geometry.global_size; }

                id<3> get_global_offset() const { return m_geometry.global_offset; }

                range<3> get_granularity() const { return m_geometry.granularity; }

                void set_debug_name(const std::string& debug_name) { m_debug_name = debug_name; }

                const std::string& get_debug_name() const { return m_debug_name; }

                /// True for the task types whose split is variable (host_compute and device_compute).
                /// Collective tasks use a fixed split.
                bool has_variable_split() const { return m_type == task_type::host_compute || m_type == task_type::device_compute; }

                /// Maps the task type to where its work runs:
                /// - device: device_compute
                /// - host: host_compute, collective and master_node
                /// - none: epoch, horizon and fence
                execution_target get_execution_target() const {
                        switch(m_type) {
                        case task_type::epoch: return execution_target::none;
                        case task_type::device_compute: return execution_target::device;
                        case task_type::host_compute:
                        case task_type::collective:
                        case task_type::master_node: return execution_target::host;
                        case task_type::horizon:
                        case task_type::fence: return execution_target::none;
                        default: assert(!"Unhandled task type"); return execution_target::none;
                        }
                }

                const reduction_set& get_reductions() const { return m_reductions; }

                /// Only epoch tasks are given an explicit action; all other factories pass a value-initialized epoch_action.
                epoch_action get_epoch_action() const { return m_epoch_action; }

                /// Only fence tasks are created with a fence promise; for every other task type this returns nullptr.
                fence_promise* get_fence_promise() const { return m_fence_promise.get(); }

                // NOTE(review): placeholder. It ignores the stored launcher and returns a value-initialized Launcher.
                template <typename Launcher>
                Launcher get_launcher() const {
                        return {};
                } // placeholder

                /// Invokes the stored launcher with `args`, selecting the device or host overload by argument types.
                /// Precondition: the task was created with a launcher. The epoch, horizon and fence factories pass
                /// nullptr, so calling this on such a task would dereference a null pointer.
                template <typename... Args>
                auto launch(Args&&... args) const {
                        assert(m_launcher != nullptr && "task has no launcher");
                        return (*m_launcher)(std::forward<Args>(args)...);
                }

                /// Keeps `state` alive for at least as long as this task exists.
                void extend_lifetime(std::shared_ptr<lifetime_extending_state> state) { m_attached_state.emplace_back(std::move(state)); }

                void add_hint(std::unique_ptr<hint_base>&& h) { m_hints.emplace_back(std::move(h)); }

                /// Returns the first attached hint whose dynamic type is `Hint` (or derives from it), or nullptr if there is none.
                template <typename Hint>
                const Hint* get_hint() const {
                        static_assert(std::is_base_of_v<hint_base, Hint>, "Hint must extend hint_base");
                        for(const auto& h : m_hints) {
                                if(const auto* const ptr = dynamic_cast<const Hint*>(h.get()); ptr != nullptr) { return ptr; }
                        }
                        return nullptr;
                }

                static std::unique_ptr<task> make_epoch(task_id tid, detail::epoch_action action) {
                        return std::unique_ptr<task>(new task(tid, task_type::epoch, non_collective_group_id, task_geometry{}, nullptr, {}, {}, {}, action, nullptr));
                }

                static std::unique_ptr<task> make_host_compute(task_id tid, task_geometry geometry, std::unique_ptr<command_launcher_storage_base> launcher,
                    buffer_access_map access_map, side_effect_map side_effect_map, reduction_set reductions) {
                        return std::unique_ptr<task>(new task(tid, task_type::host_compute, non_collective_group_id, geometry, std::move(launcher), std::move(access_map),
                            std::move(side_effect_map), std::move(reductions), {}, nullptr));
                }

                static std::unique_ptr<task> make_device_compute(task_id tid, task_geometry geometry, std::unique_ptr<command_launcher_storage_base> launcher,
                    buffer_access_map access_map, reduction_set reductions) {
                        return std::unique_ptr<task>(new task(tid, task_type::device_compute, non_collective_group_id, geometry, std::move(launcher), std::move(access_map),
                            {}, std::move(reductions), {}, nullptr));
                }

                /// Collective tasks get a 1-dimensional global size of `num_collective_nodes` and granularity 1.
                static std::unique_ptr<task> make_collective(task_id tid, collective_group_id cgid, size_t num_collective_nodes,
                    std::unique_ptr<command_launcher_storage_base> launcher, buffer_access_map access_map, side_effect_map side_effect_map) {
                        const task_geometry geometry{1, detail::range_cast<3>(range(num_collective_nodes)), {}, {1, 1, 1}};
                        return std::unique_ptr<task>(
                            new task(tid, task_type::collective, cgid, geometry, std::move(launcher), std::move(access_map), std::move(side_effect_map), {}, {}, nullptr));
                }

                static std::unique_ptr<task> make_master_node(
                    task_id tid, std::unique_ptr<command_launcher_storage_base> launcher, buffer_access_map access_map, side_effect_map side_effect_map) {
                        return std::unique_ptr<task>(new task(tid, task_type::master_node, non_collective_group_id, task_geometry{}, std::move(launcher),
                            std::move(access_map), std::move(side_effect_map), {}, {}, nullptr));
                }

                static std::unique_ptr<task> make_horizon(task_id tid) {
                        return std::unique_ptr<task>(new task(tid, task_type::horizon, non_collective_group_id, task_geometry{}, nullptr, {}, {}, {}, {}, nullptr));
                }

                static std::unique_ptr<task> make_fence(
                    task_id tid, buffer_access_map access_map, side_effect_map side_effect_map, std::unique_ptr<fence_promise> fence_promise) {
                        return std::unique_ptr<task>(new task(tid, task_type::fence, non_collective_group_id, task_geometry{}, nullptr, std::move(access_map),
                            std::move(side_effect_map), {}, {}, std::move(fence_promise)));
                }

          private:
                task_id m_tid;
                task_type m_type;
                collective_group_id m_cgid;
                task_geometry m_geometry;
                std::unique_ptr<command_launcher_storage_base> m_launcher;
                buffer_access_map m_access_map;
                detail::side_effect_map m_side_effects;
                reduction_set m_reductions;
                std::string m_debug_name;
                detail::epoch_action m_epoch_action;
                // TODO I believe that `struct task` should not store command_group_launchers, fence_promise or other state that is related to execution instead of
                // abstract DAG building. For user-initialized buffers we already notify the runtime -> executor of this state directly. Maybe also do that for these.
                std::unique_ptr<fence_promise> m_fence_promise;
                std::vector<std::shared_ptr<lifetime_extending_state>> m_attached_state;
                std::vector<std::unique_ptr<hint_base>> m_hints;

                task(task_id tid, task_type type, collective_group_id cgid, task_geometry geometry, std::unique_ptr<command_launcher_storage_base> launcher,
                    buffer_access_map access_map, detail::side_effect_map side_effects, reduction_set reductions, detail::epoch_action epoch_action,
                    std::unique_ptr<fence_promise> fence_promise)
                    : m_tid(tid), m_type(type), m_cgid(cgid), m_geometry(geometry), m_launcher(std::move(launcher)), m_access_map(std::move(access_map)),
                      m_side_effects(std::move(side_effects)), m_reductions(std::move(reductions)), m_epoch_action(epoch_action),
                      m_fence_promise(std::move(fence_promise)) {
                        // Only task types with a variable split may use a granularity other than 1.
                        assert(type == task_type::host_compute || type == task_type::device_compute || get_granularity().size() == 1);
                        // Only host tasks can have side effects
                        assert(this->m_side_effects.empty() || type == task_type::host_compute || type == task_type::collective || type == task_type::master_node
                               || type == task_type::fence);
                }
        };
299

300
        [[nodiscard]] std::string print_task_debug_label(const task& tsk, bool title_case = false);
301

302
        /// Determines which overlapping regions appear between write accesses when the iteration space of `tsk` is split into `chunks`.
303
        std::unordered_map<buffer_id, region<3>> detect_overlapping_writes(const task& tsk, const box_vector<3>& chunks);
304

305
} // namespace detail
306
} // namespace celerity
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc