
celerity / celerity-runtime / 8340717687
19 Mar 2024 09:36AM UTC coverage: 94.724% (+0.09%) from 94.63%

Pull Request #252: [IDAG] Communication & Receive Arbitration (github, fknorr)
Add new communicator / receive_arbiter infrastructure

2967 of 3316 branches covered (89.48%); branch coverage is included in the aggregate %.
299 of 301 new or added lines in 6 files covered (99.34%).
1 existing line in 1 file is now uncovered.
6872 of 7071 relevant lines covered (97.19%).
195749.41 hits per line.

Source File: /src/mpi_communicator.cc (98.76% covered)
#include "mpi_communicator.h"
#include "log.h"
#include "mpi_support.h"
#include "ranges.h"

#include <climits>
#include <cstddef>

#include <mpi.h>

namespace celerity::detail::mpi_detail {

/// async_event wrapper around an MPI_Request.
class mpi_event final : public async_event_impl {
  public:
        explicit mpi_event(MPI_Request req) : m_req(req) {}

        mpi_event(const mpi_event&) = delete;
        mpi_event(mpi_event&&) = delete;
        mpi_event& operator=(const mpi_event&) = delete;
        mpi_event& operator=(mpi_event&&) = delete;

        ~mpi_event() override {
                // MPI_Request_free is always incorrect for our use case: events originate from an Isend or Irecv, which must ensure that the user-provided buffer
                // remains live until the operation has completed.
                MPI_Wait(&m_req, MPI_STATUS_IGNORE);
        }

        bool is_complete() const override {
                int flag = -1;
                MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE);
                return flag != 0;
        }

  private:
        mutable MPI_Request m_req;
};
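
// Usage sketch (hypothetical caller, for illustration): wrap the request of a non-blocking
// MPI call in an mpi_event to obtain a pollable async_event, mirroring what send_payload /
// receive_payload below do.
//
//     MPI_Request req = MPI_REQUEST_NULL;
//     MPI_Isend(buffer, 1, datatype, rank, tag, comm, &req);
//     async_event evt = make_async_event<mpi_event>(req);
//     while(!evt.is_complete()) { /* poll, do other work */ }
//     // If evt were destroyed before completion, ~mpi_event() would block in MPI_Wait,
//     // so the operation has finished before the caller may free `buffer`.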

constexpr int pilot_exchange_tag = mpi_support::TAG_COMMUNICATOR;
constexpr int first_message_tag = pilot_exchange_tag + 1;

constexpr int message_id_to_mpi_tag(message_id msgid) {
        // If the resulting tag would overflow INT_MAX in a long-running program with many nodes, we wrap around to `first_message_tag` instead, assuming that
        // there will never be a way to cause temporal ambiguity between transfers that are 2^31 message ids apart.
        msgid %= static_cast<message_id>(INT_MAX - first_message_tag);
        return first_message_tag + static_cast<int>(msgid);
}
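
// Illustrative compile-time checks (hypothetical additions, not needed for correctness):
// message id 0 maps to first_message_tag, and ids wrap around with a period of
// INT_MAX - first_message_tag, so two ids exactly one period apart reuse the same tag.
static_assert(message_id_to_mpi_tag(0) == first_message_tag);
static_assert(message_id_to_mpi_tag(static_cast<message_id>(INT_MAX - first_message_tag)) == first_message_tag);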

constexpr int node_id_to_mpi_rank(const node_id nid) {
        assert(nid <= static_cast<node_id>(INT_MAX));
        return static_cast<int>(nid);
}

constexpr node_id mpi_rank_to_node_id(const int rank) {
        assert(rank >= 0);
        return static_cast<node_id>(rank);
}

/// Strides that only differ e.g. in their dim0 allocation size are equivalent when adjusting the base pointer. This not only improves mpi_communicator type
/// cache efficiency, but is in fact necessary to make sure all boxes that instruction_graph_generator emits for send instructions and inbound pilots are
/// representable in the 32-bit integer world of MPI.
/// @tparam Void Either `void` or `const void`.
template <typename Void>
constexpr std::tuple<Void*, communicator::stride> normalize_strided_pointer(Void* ptr, communicator::stride stride) {
        using byte_pointer_t = std::conditional_t<std::is_const_v<Void>, const std::byte*, std::byte*>;

        // drop leading buffer dimensions with extent 1, which allows us to do pointer adjustment in d1 / d2
        while(stride.allocation_range[0] == 1 && stride.allocation_range[1] * stride.allocation_range[2] > 1) {
                stride.allocation_range[0] = stride.allocation_range[1], stride.allocation_range[1] = stride.allocation_range[2], stride.allocation_range[2] = 1;
                stride.transfer.range[0] = stride.transfer.range[1], stride.transfer.range[1] = stride.transfer.range[2], stride.transfer.range[2] = 1;
                stride.transfer.offset[0] = stride.transfer.offset[1], stride.transfer.offset[1] = stride.transfer.offset[2], stride.transfer.offset[2] = 0;
        }

        // adjust base pointer to remove the offset
        const auto offset_elements = stride.transfer.offset[0] * stride.allocation_range[1] * stride.allocation_range[2];
        ptr = static_cast<byte_pointer_t>(ptr) + offset_elements * stride.element_size;
        stride.transfer.offset[0] = 0;

        // clamp allocation size to subrange (MPI will not access memory beyond subrange.range anyway)
        stride.allocation_range[0] = stride.transfer.range[0];

        // TODO we can normalize further if we accept arbitrarily large scalar types (via MPI contiguous / struct types):
        //   - collapse fast dimensions if contiguous via `stride.element_size *= stride.subrange.range[d]`
        //   - factorize stride coordinates: `element_size *= gcd(allocation[0], offset[0], range[0], allocation[1], ...)`
        // Doing all this will complicate instruction_graph_generator_detail::split_into_communicator_compatible_boxes though.
        return {ptr, stride};
}
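
// Worked example (hypothetical values, for illustration): a stride with allocation_range
// {1, 8, 16}, transfer.range {1, 4, 16}, transfer.offset {0, 2, 0} and element_size 4 first
// drops the leading extent-1 dimension, yielding allocation_range {8, 16, 1}, transfer.range
// {4, 16, 1} and transfer.offset {2, 0, 0}. The base pointer is then advanced by
// 2 * 16 * 1 = 32 elements (128 bytes) so the offset becomes {0, 0, 0}, and allocation_range[0]
// is clamped to the transfer range, giving {4, 16, 1}: a fully contiguous transfer described
// entirely with small integer extents.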

} // namespace celerity::detail::mpi_detail

namespace celerity::detail {

mpi_communicator::mpi_communicator(collective_clone_from_tag /* tag */, MPI_Comm mpi_comm) : m_mpi_comm(MPI_COMM_NULL) {
        assert(mpi_comm != MPI_COMM_NULL);
#if MPI_VERSION < 3
        // MPI 2 only has Comm_dup - we assume that the user has not done any obscure things to MPI_COMM_WORLD
        MPI_Comm_dup(mpi_comm, &m_mpi_comm);
#else
        // MPI >= 3.0 provides MPI_Comm_dup_with_info, which allows us to reset all implementation hints on the communicator to our liking
        MPI_Info info;
        MPI_Info_create(&info);
        // See the OpenMPI manpage for MPI_Comm_set_info for keys and values
        MPI_Info_set(info, "mpi_assert_no_any_tag", "true");       // promise never to use MPI_ANY_TAG (we _do_ use MPI_ANY_SOURCE for pilots)
        MPI_Info_set(info, "mpi_assert_exact_length", "true");     // promise to exactly match sizes between corresponding MPI_Send and MPI_Recv calls
        MPI_Info_set(info, "mpi_assert_allow_overtaking", "true"); // we do not care about message ordering since we disambiguate by tag
        MPI_Comm_dup_with_info(mpi_comm, info, &m_mpi_comm);
        MPI_Info_free(&info);
#endif
}

mpi_communicator::~mpi_communicator() {
        // All asynchronous sends / receives must have completed at this point - unfortunately we have no easy way of checking this here.

        // Await the completion of all outbound pilot sends. The blocking-wait should usually be unnecessary because completion of payload-sends should imply
        // completion of the outbound-pilot sends, although there is no real guarantee of this given MPI's freedom to buffer transfers however it likes.
        // MPI_Wait will also free the async request, so we use this function unconditionally.
        for(auto& outbound : m_outbound_pilots) {
                MPI_Wait(&outbound.request, MPI_STATUS_IGNORE);
        }

        // We always re-start the pilot Irecv immediately, so we need to MPI_Cancel the last such request (and then free it using MPI_Wait).
        if(m_inbound_pilot.request != MPI_REQUEST_NULL) {
                MPI_Cancel(&m_inbound_pilot.request);
                MPI_Wait(&m_inbound_pilot.request, MPI_STATUS_IGNORE);
        }

        // MPI_Comm_free is itself a collective, but since this call happens from a destructor we implicitly guarantee that it can't be re-ordered against any
        // other collective operation on this communicator.
        MPI_Comm_free(&m_mpi_comm);
}

size_t mpi_communicator::get_num_nodes() const {
        int size = -1;
        MPI_Comm_size(m_mpi_comm, &size);
        assert(size > 0);
        return static_cast<size_t>(size);
}

node_id mpi_communicator::get_local_node_id() const {
        int rank = -1;
        MPI_Comm_rank(m_mpi_comm, &rank);
        return mpi_detail::mpi_rank_to_node_id(rank);
}

void mpi_communicator::send_outbound_pilot(const outbound_pilot& pilot) {
        CELERITY_DEBUG("[mpi] pilot -> N{} (MSG{}, {}, {})", pilot.to, pilot.message.id, pilot.message.transfer_id, pilot.message.box);

        assert(pilot.to < get_num_nodes());
        assert(pilot.to != get_local_node_id());

        // Initiate Isend as early as possible to hide latency.
        in_flight_pilot newly_in_flight;
        newly_in_flight.message = std::make_unique<pilot_message>(pilot.message);
        MPI_Isend(newly_in_flight.message.get(), sizeof *newly_in_flight.message, MPI_BYTE, mpi_detail::node_id_to_mpi_rank(pilot.to),
            mpi_detail::pilot_exchange_tag, m_mpi_comm, &newly_in_flight.request);

        // Collect finished sends (TODO consider rate-limiting this to avoid quadratic behavior)
        constexpr auto pilot_send_finished = [](in_flight_pilot& already_in_flight) {
                int flag = -1;
                MPI_Test(&already_in_flight.request, &flag, MPI_STATUS_IGNORE);
                return already_in_flight.request == MPI_REQUEST_NULL;
        };
        m_outbound_pilots.erase(std::remove_if(m_outbound_pilots.begin(), m_outbound_pilots.end(), pilot_send_finished), m_outbound_pilots.end());

        // Keep allocation until Isend has completed
        m_outbound_pilots.push_back(std::move(newly_in_flight));
}

std::vector<inbound_pilot> mpi_communicator::poll_inbound_pilots() {
        // Irecv needs to be called initially, and after receiving each pilot to enqueue the next operation.
        const auto begin_receiving_next_pilot = [this] {
                assert(m_inbound_pilot.message != nullptr);
                assert(m_inbound_pilot.request == MPI_REQUEST_NULL);
                MPI_Irecv(m_inbound_pilot.message.get(), sizeof *m_inbound_pilot.message, MPI_BYTE, MPI_ANY_SOURCE, mpi_detail::pilot_exchange_tag, m_mpi_comm,
                    &m_inbound_pilot.request);
        };

        if(m_inbound_pilot.request == MPI_REQUEST_NULL) {
                // This is the first call to poll_inbound_pilots, spin up the pilot-receiving machinery - we don't do this unconditionally in the constructor
                // because communicators for collective groups do not deal with pilots
                m_inbound_pilot.message = std::make_unique<pilot_message>();
                begin_receiving_next_pilot();
        }

        // MPI might have received and buffered multiple inbound pilots, collect all of them in a loop
        std::vector<inbound_pilot> received_pilots;
        for(;;) {
                int flag = -1;
                MPI_Status status;
                MPI_Test(&m_inbound_pilot.request, &flag, &status);
                if(flag == 0 /* incomplete */) {
                        return received_pilots; // no more pilots in queue, we're done collecting
                }

                const inbound_pilot pilot{mpi_detail::mpi_rank_to_node_id(status.MPI_SOURCE), *m_inbound_pilot.message};
                begin_receiving_next_pilot(); // initiate the next receive asap

                CELERITY_DEBUG("[mpi] pilot <- N{} (MSG{}, {} {})", pilot.from, pilot.message.id, pilot.message.transfer_id, pilot.message.box);
                received_pilots.push_back(pilot);
        }
}
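
// Protocol sketch (hypothetical caller, for illustration): before a strided payload is sent,
// the sender posts a small fixed-size pilot describing the transfer; the receiver polls for
// buffered pilots and only then posts the matching receive.
//
//     sender:   comm->send_outbound_pilot(pilot);               // tiny Isend on pilot_exchange_tag
//               comm->send_payload(pilot.to, pilot.message.id, send_base, send_stride);
//     receiver: for(const auto& p : comm->poll_inbound_pilots())
//                   comm->receive_payload(p.from, p.message.id, recv_base, recv_stride);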

async_event mpi_communicator::send_payload(const node_id to, const message_id msgid, const void* const base, const stride& stride) {
        CELERITY_DEBUG("[mpi] payload -> N{} (MSG{}) from {} ({}) {}x{}", to, msgid, base, stride.allocation_range, stride.transfer, stride.element_size);

        assert(to < get_num_nodes());
        assert(to != get_local_node_id());

        MPI_Request req = MPI_REQUEST_NULL;
        const auto [adjusted_base, normalized_stride] = mpi_detail::normalize_strided_pointer(base, stride);
        MPI_Isend(
            adjusted_base, 1, get_array_type(normalized_stride), mpi_detail::node_id_to_mpi_rank(to), mpi_detail::message_id_to_mpi_tag(msgid), m_mpi_comm, &req);
        return make_async_event<mpi_detail::mpi_event>(req);
}

async_event mpi_communicator::receive_payload(const node_id from, const message_id msgid, void* const base, const stride& stride) {
        CELERITY_DEBUG("[mpi] payload <- N{} (MSG{}) into {} ({}) {}x{}", from, msgid, base, stride.allocation_range, stride.transfer, stride.element_size);

        assert(from < get_num_nodes());
        assert(from != get_local_node_id());

        MPI_Request req = MPI_REQUEST_NULL;
        const auto [adjusted_base, normalized_stride] = mpi_detail::normalize_strided_pointer(base, stride);
        MPI_Irecv(
            adjusted_base, 1, get_array_type(normalized_stride), mpi_detail::node_id_to_mpi_rank(from), mpi_detail::message_id_to_mpi_tag(msgid), m_mpi_comm, &req);
        return make_async_event<mpi_detail::mpi_event>(req);
}
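
// Matched-pair sketch (hypothetical caller, for illustration): both peers derive the MPI tag
// from the same message id carried by the pilot, so the Isend and Irecv pair up even though
// the sender's and receiver's local strides (and therefore MPI datatypes) may differ.
//
//     auto send_ev = comm_a->send_payload(/* to */ 1, msgid, send_base, send_stride);
//     auto recv_ev = comm_b->receive_payload(/* from */ 0, msgid, recv_base, recv_stride);
//     // Poll both events until complete; both buffers must stay alive until then.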

std::unique_ptr<communicator> mpi_communicator::collective_clone() { return std::make_unique<mpi_communicator>(collective_clone_from, m_mpi_comm); }

void mpi_communicator::collective_barrier() { MPI_Barrier(m_mpi_comm); }

MPI_Datatype mpi_communicator::get_scalar_type(const size_t bytes) {
        if(const auto it = m_scalar_type_cache.find(bytes); it != m_scalar_type_cache.end()) { return it->second.get(); }

        assert(bytes > 0);
        assert(bytes <= static_cast<size_t>(INT_MAX));
        MPI_Datatype type = MPI_DATATYPE_NULL;
        MPI_Type_contiguous(static_cast<int>(bytes), MPI_BYTE, &type);
        MPI_Type_commit(&type);
        m_scalar_type_cache.emplace(bytes, unique_datatype(type));
        return type;
}
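
// Example (hypothetical): get_scalar_type(12) builds and caches a 12-byte contiguous datatype
// out of MPI_BYTE, so elements of arbitrary size can be transferred without requiring a
// matching built-in MPI datatype.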

MPI_Datatype mpi_communicator::get_array_type(const stride& stride) {
        if(const auto it = m_array_type_cache.find(stride); it != m_array_type_cache.end()) { return it->second.get(); }

        const int dims = detail::get_effective_dims(stride.allocation_range);
        assert(detail::get_effective_dims(stride.transfer) <= dims);

        // MPI (understandably) does not recognize a 0-dimensional subarray as a scalar
        if(dims == 0) { return get_scalar_type(stride.element_size); }

        // TODO - can we get runaway behavior from constructing too many MPI data types, especially with Spectrum MPI?
        // TODO - eagerly create MPI types ahead-of-time whenever we send or receive a pilot to reduce latency?

        int size_array[3];
        int subsize_array[3];
        int start_array[3];
        for(int d = 0; d < 3; ++d) {
                // The instruction graph generator must only ever emit transfers which can be described with a signed-int stride
                assert(stride.allocation_range[d] <= static_cast<size_t>(INT_MAX));
                assert(stride.transfer.range[d] <= static_cast<size_t>(INT_MAX));
                assert(stride.transfer.offset[d] <= static_cast<size_t>(INT_MAX));
                size_array[d] = static_cast<int>(stride.allocation_range[d]);
                subsize_array[d] = static_cast<int>(stride.transfer.range[d]);
                start_array[d] = static_cast<int>(stride.transfer.offset[d]);
        }

        MPI_Datatype type = MPI_DATATYPE_NULL;
        MPI_Type_create_subarray(dims, size_array, subsize_array, start_array, MPI_ORDER_C, get_scalar_type(stride.element_size), &type);
        MPI_Type_commit(&type);

        m_array_type_cache.emplace(stride, unique_datatype(type));
        return type;
}
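
// Worked example (hypothetical values, for illustration): a stride with element_size 8,
// allocation_range {20, 32, 1}, transfer.range {4, 16, 1} and transfer.offset {2, 8, 0} has
// two effective dimensions, so the resulting call is equivalent to
//
//     MPI_Type_create_subarray(2, /* sizes */ {20, 32}, /* subsizes */ {4, 16}, /* starts */ {2, 8},
//         MPI_ORDER_C, /* 8-byte scalar from get_scalar_type(8) */, &type);
//
// i.e. a 4x16-element window at offset (2, 8) inside a 20x32-element row-major allocation.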

void mpi_communicator::datatype_deleter::operator()(MPI_Datatype dtype) const { //
        MPI_Type_free(&dtype);
}

} // namespace celerity::detail