celerity / celerity-runtime / build 11251914715

09 Oct 2024 09:13AM UTC coverage: 95.051% (-0.05%) from 95.102%
Triggered by a push via github
fknorr: Update changelog for new queue APIs

3021 of 3426 branches covered (88.18%). Branch coverage is included in the aggregate percentage.
6659 of 6758 relevant lines covered (98.54%)
1492206.97 hits per line

Source file: /src/task_manager.cc (98.83% covered)

#include "task_manager.h"

#include "access_modes.h"
#include "log.h"
#include "recorders.h"

namespace celerity {
namespace detail {

        task_manager::task_manager(size_t num_collective_nodes, detail::task_recorder* recorder, const policy_set& error_policy)
            : m_num_collective_nodes(num_collective_nodes), m_policy(error_policy), m_task_recorder(recorder) {
                // We manually generate the initial epoch task, which we treat as if it has been reached immediately.
                auto reserve = m_task_buffer.reserve_task_entry(await_free_task_slot_callback());
                auto initial_epoch = task::make_epoch(initial_epoch_task, epoch_action::none);
                if(m_task_recorder != nullptr) { m_task_recorder->record(task_record(*initial_epoch, {})); }
                m_task_buffer.put(std::move(reserve), std::move(initial_epoch));
        }

        void task_manager::notify_buffer_created(const buffer_id bid, const range<3>& range, const bool host_initialized) {
                const auto [iter, inserted] = m_buffers.emplace(bid, range);
                assert(inserted);
                auto& buffer = iter->second;
                if(host_initialized) { buffer.last_writers.update_region(subrange<3>({}, range), m_epoch_for_new_tasks); }
        }

        void task_manager::notify_buffer_debug_name_changed(const buffer_id bid, const std::string& debug_name) { m_buffers.at(bid).debug_name = debug_name; }

        void task_manager::notify_buffer_destroyed(const buffer_id bid) {
                assert(m_buffers.count(bid) != 0);
                m_buffers.erase(bid);
        }

        void task_manager::notify_host_object_created(const host_object_id hoid) { m_host_objects.emplace(hoid, host_object_state()); }

        void task_manager::notify_host_object_destroyed(const host_object_id hoid) {
                assert(m_host_objects.count(hoid) != 0);
                m_host_objects.erase(hoid);
        }

        const task* task_manager::find_task(task_id tid) const { return m_task_buffer.find_task(tid); } // uncovered in this report

        bool task_manager::has_task(task_id tid) const { return m_task_buffer.has_task(tid); }

        // Note that we assume tasks are not modified after their initial creation, which is why
        // we don't need to worry about thread-safety after returning the task pointer.
        const task* task_manager::get_task(task_id tid) const { return m_task_buffer.get_task(tid); }

        void task_manager::notify_horizon_reached(task_id horizon_tid) {
                // m_latest_horizon_reached does not need synchronization (see definition), all other accesses are implicitly synchronized.

                assert(m_task_buffer.get_task(horizon_tid)->get_type() == task_type::horizon);
                assert(!m_latest_horizon_reached || *m_latest_horizon_reached < horizon_tid);
                assert(m_latest_epoch_reached.get() < horizon_tid);

                if(m_latest_horizon_reached) { m_latest_epoch_reached.set(*m_latest_horizon_reached); }

                m_latest_horizon_reached = horizon_tid;
        }

        void task_manager::notify_epoch_reached(task_id epoch_tid) {
                // m_latest_horizon_reached does not need synchronization (see definition), all other accesses are implicitly synchronized.

                assert(get_task(epoch_tid)->get_type() == task_type::epoch);
                assert(!m_latest_horizon_reached || *m_latest_horizon_reached < epoch_tid);
                assert(m_latest_epoch_reached.get() < epoch_tid);

                m_latest_epoch_reached.set(epoch_tid);
                m_latest_horizon_reached = std::nullopt; // Any non-applied horizon is now behind the epoch and will therefore never become an epoch itself
        }

        void task_manager::await_epoch(task_id epoch) { m_latest_epoch_reached.await(epoch); }

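        // Collects the union of the regions that the given task accesses on buffer bid with any of the given access modes.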
        region<3> get_requirements(const task& tsk, buffer_id bid, const std::vector<sycl::access::mode>& modes) {
                const auto& access_map = tsk.get_buffer_access_map();
                const subrange<3> full_range{tsk.get_global_offset(), tsk.get_global_size()};
                box_vector<3> boxes;
                for(auto m : modes) {
                        const auto req = access_map.get_mode_requirements(bid, m, tsk.get_dimensions(), full_range, tsk.get_global_size());
                        boxes.insert(boxes.end(), req.get_boxes().begin(), req.get_boxes().end());
                }
                return region(std::move(boxes));
        }

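        // Computes the dependencies of a freshly created task: data-flow and anti-dependencies from buffer accesses and
        // reductions, true dependencies from host-object side effects and collective-group serialization, and a fallback
        // dependency on the current epoch for tasks that would otherwise have no true dependency.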
        void task_manager::compute_dependencies(task& tsk) {
                using namespace sycl::access;

                const auto& access_map = tsk.get_buffer_access_map();

                auto buffers = access_map.get_accessed_buffers();
                for(const auto& reduction : tsk.get_reductions()) {
                        buffers.emplace(reduction.bid);
                }

                const box<3> scalar_box({0, 0, 0}, {1, 1, 1});

                for(const auto bid : buffers) {
                        auto& buffer = m_buffers.at(bid);
                        const auto modes = access_map.get_access_modes(bid);

                        std::optional<reduction_info> reduction;
                        for(const auto& maybe_reduction : tsk.get_reductions()) {
                                if(maybe_reduction.bid == bid) {
                                        if(reduction) { throw std::runtime_error(fmt::format("Multiple reductions attempt to write buffer {} in task {}", bid, tsk.get_id())); }
                                        reduction = maybe_reduction;
                                }
                        }

                        if(reduction && !modes.empty()) {
                                throw std::runtime_error(
                                    fmt::format("Buffer {} is both required through an accessor and used as a reduction output in task {}", bid, tsk.get_id()));
                        }

                        // Determine reader dependencies
                        if(std::any_of(modes.cbegin(), modes.cend(), detail::access::mode_traits::is_consumer) || (reduction.has_value() && reduction->init_from_buffer)) {
                                auto read_requirements = get_requirements(tsk, bid, {detail::access::consumer_modes.cbegin(), detail::access::consumer_modes.cend()});
                                if(reduction.has_value()) { read_requirements = region_union(read_requirements, scalar_box); }
                                const auto last_writers = buffer.last_writers.get_region_values(read_requirements);

                                box_vector<3> uninitialized_reads;
                                for(const auto& [box, writer] : last_writers) {
                                        // host-initialized buffers are last-written by the current epoch
                                        if(writer.has_value()) {
                                                add_dependency(tsk, *m_task_buffer.get_task(*writer), dependency_kind::true_dep, dependency_origin::dataflow);
                                        } else if(m_policy.uninitialized_read_error != error_policy::ignore) {
                                                uninitialized_reads.push_back(box);
                                        }
                                }
                                if(!uninitialized_reads.empty()) {
                                        utils::report_error(m_policy.uninitialized_read_error,
                                            "{} declares a reading access on uninitialized {} {}. Make sure to construct the accessor with no_init if possible.",
                                            print_task_debug_label(tsk, true /* title case */), print_buffer_debug_label(bid), region(std::move(uninitialized_reads)));
                                }
                        }

                        // Update last writers and determine anti-dependencies
                        if(std::any_of(modes.cbegin(), modes.cend(), detail::access::mode_traits::is_producer) || reduction.has_value()) {
                                auto write_requirements = get_requirements(tsk, bid, {detail::access::producer_modes.cbegin(), detail::access::producer_modes.cend()});
                                if(reduction.has_value()) { write_requirements = region_union(write_requirements, scalar_box); }
                                if(write_requirements.empty()) continue;

                                const auto last_writers = buffer.last_writers.get_region_values(write_requirements);
                                for(auto& p : last_writers) {
                                        if(p.second == std::nullopt) continue;
                                        task* last_writer = m_task_buffer.get_task(*p.second);

                                        // Determine anti-dependencies by looking at all the dependents of the last writing task
                                        bool has_anti_dependents = false;

                                        for(auto dependent : last_writer->get_dependents()) {
                                                if(dependent.node->get_id() == tsk.get_id()) {
                                                        // This can happen
                                                        // - if a task writes to two or more buffers with the same last writer
                                                        // - if the task itself also needs read access to that buffer (R/W access)
                                                        continue;
                                                }
                                                const auto dependent_read_requirements =
                                                    get_requirements(*dependent.node, bid, {detail::access::consumer_modes.cbegin(), detail::access::consumer_modes.cend()});
                                                // Only add an anti-dependency if we are really writing over the region read by this task
                                                if(!region_intersection(write_requirements, dependent_read_requirements).empty()) {
                                                        add_dependency(tsk, *dependent.node, dependency_kind::anti_dep, dependency_origin::dataflow);
                                                        has_anti_dependents = true;
                                                }
                                        }

                                        if(!has_anti_dependents) {
                                                // If no intermediate consumers exist, add an anti-dependency on the last writer directly.
                                                // Note that unless this task is a pure producer, a true dependency will be created and this is a no-op.
                                                // While it might not always make total sense to have anti-dependencies between (pure) producers without an
                                                // intermediate consumer, we at least have a defined behavior, and the thus enforced ordering of tasks
                                                // likely reflects what the user expects.
                                                add_dependency(tsk, *last_writer, dependency_kind::anti_dep, dependency_origin::dataflow);
                                        }
                                }

                                buffer.last_writers.update_region(write_requirements, tsk.get_id());
                        }
                }

                for(const auto& side_effect : tsk.get_side_effect_map()) {
                        const auto [hoid, order] = side_effect;
                        auto& host_object = m_host_objects.at(hoid);
                        if(host_object.last_side_effect.has_value()) {
                                add_dependency(tsk, *m_task_buffer.get_task(*host_object.last_side_effect), dependency_kind::true_dep, dependency_origin::dataflow);
                        }
                        host_object.last_side_effect = tsk.get_id();
                }

                if(auto cgid = tsk.get_collective_group_id(); cgid != 0) {
                        if(auto prev = m_last_collective_tasks.find(cgid); prev != m_last_collective_tasks.end()) {
                                add_dependency(tsk, *m_task_buffer.get_task(prev->second), dependency_kind::true_dep, dependency_origin::collective_group_serialization);
                                m_last_collective_tasks.erase(prev);
                        }
                        m_last_collective_tasks.emplace(cgid, tsk.get_id());
                }

                // Tasks without any other true-dependency must depend on the last epoch to ensure they cannot be re-ordered before the epoch
                if(const auto deps = tsk.get_dependencies();
                    std::none_of(deps.begin(), deps.end(), [](const task::dependency d) { return d.kind == dependency_kind::true_dep; })) {
                        add_dependency(tsk, *m_task_buffer.get_task(m_epoch_for_new_tasks), dependency_kind::true_dep, dependency_origin::last_epoch);
                }
        }

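        // Places a fully constructed task into the ring buffer and adds it to the execution front.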
        task& task_manager::register_task_internal(task_ring_buffer::reservation&& reserve, std::unique_ptr<task> task) {
                auto& task_ref = *task;
                assert(task != nullptr);
                m_task_buffer.put(std::move(reserve), std::move(task));
                m_execution_front.insert(&task_ref);
                return task_ref;
        }

        void task_manager::invoke_callbacks(const task* tsk) const {
                for(const auto& cb : m_task_callbacks) {
                        cb(tsk);
                }
                if(m_task_recorder != nullptr) {
                        m_task_recorder->record(task_record(*tsk, [this](const buffer_id bid) { return m_buffers.at(bid).debug_name; }));
                }
        }

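        // Records a dependency edge, removes the dependee from the execution front and keeps track of the longest pseudo critical path.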
        void task_manager::add_dependency(task& depender, task& dependee, dependency_kind kind, dependency_origin origin) {
                assert(&depender != &dependee);
                depender.add_dependency({&dependee, kind, origin});
                m_execution_front.erase(&dependee);
                m_max_pseudo_critical_path_length = std::max(m_max_pseudo_critical_path_length, depender.get_pseudo_critical_path_length());
        }

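        // A new horizon is due when the pseudo critical path has grown by at least the horizon step size since the
        // last horizon, or when the execution front has grown beyond the configured maximum parallelism.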
        bool task_manager::need_new_horizon() const {
                const bool need_seq_horizon = m_max_pseudo_critical_path_length - m_current_horizon_critical_path_length >= m_task_horizon_step_size;
                const bool need_para_horizon = static_cast<int>(m_execution_front.size()) >= m_task_horizon_max_parallelism;
                const bool need_horizon = need_seq_horizon || need_para_horizon;
                CELERITY_TRACE("Horizon decision: {} - seq: {} para: {} - crit_p: {} exec_f: {}", need_horizon, need_seq_horizon, need_para_horizon,
                    m_current_horizon_critical_path_length, m_execution_front.size());
                return need_horizon;
        }

        task& task_manager::reduce_execution_front(task_ring_buffer::reservation&& reserve, std::unique_ptr<task> new_front) {
                // add dependencies from a copy of the front to this task
                const auto current_front = m_execution_front;
                for(task* front_task : current_front) {
                        add_dependency(*new_front, *front_task, dependency_kind::true_dep, dependency_origin::execution_front);
                }
                assert(m_execution_front.empty());
                return register_task_internal(std::move(reserve), std::move(new_front));
        }

        void task_manager::set_epoch_for_new_tasks(const task_id epoch) {
                // apply the new epoch to buffers_last_writers and last_collective_tasks data structs
                for(auto& [_, buffer] : m_buffers) {
                        buffer.last_writers.apply_to_values([epoch](const std::optional<task_id> tid) -> std::optional<task_id> {
                                if(!tid) return tid;
                                return {std::max(epoch, *tid)};
                        });
                }
                for(auto& [_, tid] : m_last_collective_tasks) {
                        tid = std::max(epoch, tid);
                }
                for(auto& [_, host_object] : m_host_objects) {
                        if(host_object.last_side_effect.has_value() && *host_object.last_side_effect < epoch) { host_object.last_side_effect = epoch; }
                }

                m_epoch_for_new_tasks = epoch;
        }

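        // Creates a horizon task that depends on the entire current execution front. The previous horizon, if any,
        // then becomes the epoch that all subsequently created tasks implicitly depend on.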
        task_id task_manager::generate_horizon_task() {
                auto reserve = m_task_buffer.reserve_task_entry(await_free_task_slot_callback());
                const auto tid = reserve.get_tid();

                m_current_horizon_critical_path_length = m_max_pseudo_critical_path_length;
                const auto previous_horizon = m_current_horizon;
                m_current_horizon = tid;

                task& new_horizon = reduce_execution_front(std::move(reserve), task::make_horizon(*m_current_horizon));
                if(previous_horizon) { set_epoch_for_new_tasks(*previous_horizon); }

                invoke_callbacks(&new_horizon);
                return tid;
        }

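        // Creates an epoch task that subsumes the current execution front and immediately becomes the epoch for new
        // tasks, discarding any pending horizon.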
        task_id task_manager::generate_epoch_task(epoch_action action) {
                auto reserve = m_task_buffer.reserve_task_entry(await_free_task_slot_callback());
                const auto tid = reserve.get_tid();

                task& new_epoch = reduce_execution_front(std::move(reserve), task::make_epoch(tid, action));
                compute_dependencies(new_epoch);
                set_epoch_for_new_tasks(tid);

                m_current_horizon = std::nullopt; // this horizon is now behind the epoch_for_new_tasks, so it will never become an epoch itself
                m_current_horizon_critical_path_length = m_max_pseudo_critical_path_length; // the explicit epoch resets the need to create horizons

                invoke_callbacks(&new_epoch);

                // On shutdown, attempt to detect suspiciously high numbers of previous user-generated epochs
                if(action != epoch_action::shutdown) {
                        m_num_user_epochs_generated++;
                } else if(m_num_user_command_groups_submitted > 100 && m_num_user_epochs_generated * 10 >= m_num_user_command_groups_submitted) {
                        CELERITY_WARN("Your program appears to call queue::wait() excessively, which may lead to performance degradation. Consider using queue::fence() "
                                      "for data-dependent branching and employ queue::wait() for timing only on a very coarse granularity.");
                }

                return tid;
        }

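        // Creates a fence task from the given buffer accesses, side effects and promise, and wires up its dependencies.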
        task_id task_manager::generate_fence_task(buffer_access_map access_map, side_effect_map side_effects, std::unique_ptr<fence_promise> fence_promise) {
                auto reserve = m_task_buffer.reserve_task_entry(await_free_task_slot_callback());
                const auto tid = reserve.get_tid();
                task& tsk = register_task_internal(std::move(reserve), task::make_fence(tid, std::move(access_map), std::move(side_effects), std::move(fence_promise)));
                compute_dependencies(tsk);
                invoke_callbacks(&tsk);
                return tid;
        }

        task_id task_manager::get_first_in_flight_epoch() const {
                task_id current_horizon = 0;
                task_id latest_epoch = m_latest_epoch_reached.get();
                // we need either one epoch or two horizons that have yet to be executed
                // so that it is possible for task slots to be freed in the future
                for(const auto& tsk : m_task_buffer) {
                        if(tsk->get_id() <= latest_epoch) continue;
                        if(tsk->get_type() == task_type::epoch) {
                                return tsk->get_id(); // uncovered in this report
                        } else if(tsk->get_type() == task_type::horizon) {
                                if(current_horizon) return current_horizon;
                                current_horizon = tsk->get_id();
                        }
                }
                return latest_epoch;
        }

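        // Returns the callback the task ring buffer invokes when it runs out of slots: it waits until the next epoch
        // is reached so that old tasks can be deleted, and throws if no epoch or horizon is in flight to ever free a slot.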
        task_ring_buffer::wait_callback task_manager::await_free_task_slot_callback() {
                return [&](task_id previous_free_tid) {
                        if(get_first_in_flight_epoch() == m_latest_epoch_reached.get()) {
                                // verify that the epoch didn't get reached between the invocation of the callback and the in flight check
                                if(m_latest_epoch_reached.get() < previous_free_tid + 1) {
                                        throw std::runtime_error("Exhausted task slots with no horizons or epochs in flight."
                                                                 "\nLikely due to generating a very large number of tasks with no dependencies.");
                                }
                        }
                        task_id reached_epoch = m_latest_epoch_reached.await(previous_free_tid + 1);
                        m_task_buffer.delete_up_to(reached_epoch);
                };
        }

        std::string task_manager::print_buffer_debug_label(const buffer_id bid) const { return utils::make_buffer_debug_label(bid, m_buffers.at(bid).debug_name); }

} // namespace detail
} // namespace celerity