• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 13012068652

28 Jan 2025 01:59PM UTC coverage: 64.71% (+0.01%) from 64.699%
13012068652

Pull #14724

github

web-flow
Merge b15562560 into db18c3a17
Pull Request #14724: dnsdist: Add meson support

38328 of 90334 branches covered (42.43%)

Branch coverage included in aggregate %.

361 of 513 new or added lines in 35 files covered. (70.37%)

42 existing lines in 13 files now uncovered.

128150 of 166934 relevant lines covered (76.77%)

4540890.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.1
/pdns/dnsdistdist/dnsdist-tcp.cc
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22

23
#include <thread>
24
#include <netinet/tcp.h>
25
#include <queue>
26
#include <boost/format.hpp>
27

28
#include "dnsdist.hh"
29
#include "dnsdist-concurrent-connections.hh"
30
#include "dnsdist-dnsparser.hh"
31
#include "dnsdist-ecs.hh"
32
#include "dnsdist-edns.hh"
33
#include "dnsdist-nghttp2-in.hh"
34
#include "dnsdist-proxy-protocol.hh"
35
#include "dnsdist-rings.hh"
36
#include "dnsdist-tcp.hh"
37
#include "dnsdist-tcp-downstream.hh"
38
#include "dnsdist-downstream-connection.hh"
39
#include "dnsdist-tcp-upstream.hh"
40
#include "dnsparser.hh"
41
#include "dolog.hh"
42
#include "gettime.hh"
43
#include "lock.hh"
44
#include "sstuff.hh"
45
#include "tcpiohandler.hh"
46
#include "tcpiohandler-mplexer.hh"
47
#include "threadname.hh"
48

49
/* TCP: the grand design.
50
   We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
51
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
52
   we will not go there.
53

54
   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
55
   This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
56
   to guarantee performance.
57

58
   So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
59
   So whenever an answer comes in, we know where it needs to go.
60

61
   Let's start naively.
62
*/
63

64
/* Atomic 64-bit counter, initialized to zero.
   NOTE(review): presumably counts pending requests to dump the TCP states;
   the code updating and consuming it is not visible here -- confirm. */
std::atomic<uint64_t> g_tcpStatesDumpRequested{0};
65

66
/* Out-of-class definition of the lock-guarded map associating a client address
   with a counter. The map is ordered with ComboAddress::addressOnlyLessThan
   (presumably ignoring the port -- confirm against ComboAddress). */
LockGuarded<std::map<ComboAddress, size_t, ComboAddress::addressOnlyLessThan>> dnsdist::IncomingConcurrentTCPConnectionsManager::s_tcpClientsConcurrentConnectionsCount;
67

68
/* Accounts for the closed client connection, reports the number of queries and
   the connection duration to the frontend (when one is attached), then closes
   the I/O handler. */
IncomingTCPConnectionState::~IncomingTCPConnectionState()
{
  dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(d_ci.remote);

  if (d_ci.cs != nullptr) {
    timeval currentTime{};
    gettimeofday(&currentTime, nullptr);

    const auto elapsed = currentTime - d_connectionStartTime;
    /* seconds and microseconds folded into milliseconds */
    const auto durationInMs = elapsed.tv_sec * 1000 + elapsed.tv_usec / 1000;
    d_ci.cs->updateTCPMetrics(d_queriesCount, durationInMs);
  }

  /* The handler would be closed on destruction anyway, but closing it here
     guarantees that it happens before the ConnectionInfo is destroyed (which
     closes the descriptor), without depending on the declaration order of the
     members in the class. */
  d_handler.close();
}
85

86
/* Returns the protocol used by the client on this connection: DoH when the
   frontend has a DoH frontend attached, DoT when the handler is using TLS,
   plain DNS over TCP otherwise. */
dnsdist::Protocol IncomingTCPConnectionState::getProtocol() const
{
  if (d_ci.cs->dohFrontend != nullptr) {
    return dnsdist::Protocol::DoH;
  }

  return d_handler.isTLS() ? dnsdist::Protocol::DoT : dnsdist::Protocol::DoTCP;
}
96

97
size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
98
{
182✔
99
  return t_downstreamTCPConnectionsManager.clear();
182✔
100
}
182✔
101

102
std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now)
103
{
2,343✔
104
  auto downstream = getOwnedDownstreamConnection(backend, tlvs);
2,343✔
105

106
  if (!downstream) {
2,343✔
107
    /* we don't have a connection to this backend owned yet, let's get one (it might not be a fresh one, though) */
108
    downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, backend, now, std::string());
2,311✔
109
    if (backend->d_config.useProxyProtocol) {
2,311✔
110
      registerOwnedDownstreamConnection(downstream);
15✔
111
    }
15✔
112
  }
2,311✔
113

114
  return downstream;
2,343✔
115
}
2,343✔
116

117
/* Forward declaration of the TCP worker thread entry point, which is handed to
   std::thread in TCPClientCollection::addTCPClientThread(). It receives the
   receiving ends of the query, cross-protocol query and cross-protocol response
   channels, the sending end of the cross-protocol response channel, and a copy
   of the list of frontends. */
static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates);
118

119
/* Sizes the worker vector for maxThreads entries, then attempts to start
   maxThreads TCP worker threads. */
TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientState*> tcpAcceptStates) :
  d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
{
  size_t remaining = maxThreads;
  while (remaining > 0) {
    addTCPClientThread(tcpAcceptStates);
    --remaining;
  }
}
126

127
/* Creates the three non-blocking channels used to talk to a new TCP worker,
   starts the (detached) worker thread and stores the worker in the next free
   slot of d_tcpclientthreads. Nothing is added when the vector is already full
   or when the thread cannot be created; errors are logged, never propagated. */
void TCPClientCollection::addTCPClientThread(std::vector<ClientState*>& tcpAcceptStates)
{
  try {
    const auto pipeBufferSize = dnsdist::configuration::getImmutableConfiguration().d_tcpInternalPipeBufferSize;
    const auto senderMode = pdns::channel::SenderBlockingMode::SenderNonBlocking;
    const auto receiverMode = pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking;

    /* incoming connections, passed from the acceptor to the worker */
    auto [connSender, connReceiver] = pdns::channel::createObjectQueue<ConnectionInfo>(senderMode, receiverMode, pipeBufferSize);

    /* cross-protocol queries passed to the worker */
    auto [xQuerySender, xQueryReceiver] = pdns::channel::createObjectQueue<CrossProtocolQuery>(senderMode, receiverMode, pipeBufferSize);

    /* cross-protocol responses, both ends are handed to the worker thread */
    auto [xResponseSender, xResponseReceiver] = pdns::channel::createObjectQueue<TCPCrossProtocolResponse>(senderMode, receiverMode, pipeBufferSize);

    vinfolog("Adding TCP Client thread");

    if (d_numthreads >= d_tcpclientthreads.size()) {
      vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
      return;
    }

    /* the worker object keeps the sending ends used to feed the thread */
    TCPWorkerThread newWorker(std::move(connSender), std::move(xQuerySender));

    try {
      std::thread workerThread(tcpClientThread, std::move(connReceiver), std::move(xQueryReceiver), std::move(xResponseReceiver), std::move(xResponseSender), tcpAcceptStates);
      workerThread.detach();
    }
    catch (const std::runtime_error& e) {
      errlog("Error creating a TCP thread: %s", e.what());
      return;
    }

    /* only account for the worker once its thread is actually running */
    d_tcpclientthreads.at(d_numthreads) = std::move(newWorker);
    ++d_numthreads;
  }
  catch (const std::exception& e) {
    errlog("Error creating TCP worker: %s", e.what());
  }
}
163

164
/* Global, owning pointer to the collection of TCP worker threads; empty until
   assigned (the assignment is not part of this section of the file). */
std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
165

166
/* Sends the responses queued on this connection, one after the other, for as
   long as the connection is active and each write completes. Returns the first
   I/O state that is not Done (the response could not be fully sent), otherwise
   marks the connection idle and returns Done. */
static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
  for (;;) {
    if (!state->active() || state->d_queuedResponses.empty()) {
      break;
    }

    DEBUGLOG("queue size is " << state->d_queuedResponses.size() << ", sending the next one");
    TCPResponse nextResponse = std::move(state->d_queuedResponses.front());
    state->d_queuedResponses.pop_front();
    state->d_state = IncomingTCPConnectionState::State::idle;

    const IOState outcome = state->sendResponse(now, std::move(nextResponse));
    if (outcome != IOState::Done) {
      return outcome;
    }
  }

  state->d_state = IncomingTCPConnectionState::State::idle;
  return IOState::Done;
}
184

185
/* Bookkeeping done once a response has been completely written to the client.
   Nothing is done for AXFR / IXFR responses. Otherwise the number of in-flight
   queries is decremented, the response is reported via ::handleResponseSent()
   (with backend details when it was not self-generated and a backend is known),
   and the response buffer and backend connection are released. */
void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse, size_t sentBytes)
{
  const auto& ids = currentResponse.d_idstate;
  if (ids.qtype == QType::AXFR || ids.qtype == QType::IXFR) {
    return;
  }

  --d_currentQueriesCount;

  const auto& backend = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds;
  if (ids.selfGenerated || !backend) {
    /* no backend involved: no latency and no backend address to report */
    ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
  }
  else {
    double udiff = ids.queryRealTime.udiff();
    vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", backend->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), sentBytes, udiff);

    auto protocolToBackend = backend->getProtocol();
    /* the backend reports DoUDP but this query was not forwarded over UDP,
       so it is accounted as DoTCP */
    if (protocolToBackend == dnsdist::Protocol::DoUDP && !ids.forwardedOverUDP) {
      protocolToBackend = dnsdist::Protocol::DoTCP;
    }
    ::handleResponseSent(ids, udiff, d_ci.remote, backend->d_config.remote, static_cast<unsigned int>(sentBytes), currentResponse.d_cleartextDH, protocolToBackend, true);
  }

  currentResponse.d_buffer.clear();
  currentResponse.d_connection.reset();
}
213

214
/* Inserts the two-byte, network-order length of the DNS query into the buffer,
   right after the proxy protocol payload (if any) and before the query itself.

   buffer: proxy protocol payload (proxyProtocolPayloadSize bytes, possibly 0)
           immediately followed by the DNS query.
   proxyProtocolPayloadSize: size of the leading proxy protocol payload.

   Throws std::runtime_error if the buffer does not hold any byte after the
   proxy protocol payload, or if the query is too large for its length to be
   expressed with 16 bits. */
static void prependSizeToTCPQuery(PacketBuffer& buffer, size_t proxyProtocolPayloadSize)
{
  if (buffer.size() <= proxyProtocolPayloadSize) {
    throw std::runtime_error("The buffer size is smaller than or equal to the proxy protocol payload size");
  }

  /* largest length that fits in the two-byte prefix */
  constexpr size_t maxQueryLen = 65535;
  const size_t queryLen = buffer.size() - proxyProtocolPayloadSize;
  if (queryLen > maxQueryLen) {
    /* refuse instead of silently truncating the length, which would
       desynchronize the stream sent to the backend */
    throw std::runtime_error("The query is too large to be sent over TCP");
  }

  const std::array<uint8_t, 2> sizeBytes{static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256)};
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  buffer.insert(buffer.begin() + static_cast<PacketBuffer::iterator::difference_type>(proxyProtocolPayloadSize), sizeBytes.begin(), sizeBytes.end());
}
227

228
bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now)
229
{
25,727✔
230
  if (d_hadErrors) {
25,727✔
231
    DEBUGLOG("not accepting new queries because we encountered some error during the processing already");
2✔
232
    return false;
2✔
233
  }
2✔
234

235
  // for DoH, this is already handled by the underlying library
236
  if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
25,725✔
237
    DEBUGLOG("not accepting new queries because we already have " << d_currentQueriesCount << " out of " << d_ci.cs->d_maxInFlightQueriesPerConn);
2,481✔
238
    return false;
2,481✔
239
  }
2,481✔
240

241
  const auto& currentConfig = dnsdist::configuration::getCurrentRuntimeConfiguration();
23,244✔
242
  if (currentConfig.d_maxTCPQueriesPerConn != 0 && d_queriesCount > currentConfig.d_maxTCPQueriesPerConn) {
23,244✔
243
    vinfolog("not accepting new queries from %s because it reached the maximum number of queries per conn (%d / %d)", d_ci.remote.toStringWithPort(), d_queriesCount, currentConfig.d_maxTCPQueriesPerConn);
208✔
244
    return false;
208✔
245
  }
208✔
246

247
  if (maxConnectionDurationReached(currentConfig.d_maxTCPConnectionDuration, now)) {
23,036!
248
    vinfolog("not accepting new queries from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
×
249
    return false;
×
250
  }
×
251

252
  return true;
23,036✔
253
}
23,036✔
254

255
void IncomingTCPConnectionState::resetForNewQuery()
256
{
23,036✔
257
  d_buffer.clear();
23,036✔
258
  d_currentPos = 0;
23,036✔
259
  d_querySize = 0;
23,036✔
260
  d_state = State::waitingForQuery;
23,036✔
261
}
23,036✔
262

263
std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs)
264
{
2,343✔
265
  auto connIt = d_ownedConnectionsToBackend.find(backend);
2,343✔
266
  if (connIt == d_ownedConnectionsToBackend.end()) {
2,343✔
267
    DEBUGLOG("no owned connection found for " << backend->getName());
2,311✔
268
    return nullptr;
2,311✔
269
  }
2,311✔
270

271
  for (auto& conn : connIt->second) {
32!
272
    if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) {
32!
273
      DEBUGLOG("Got one owned connection accepting more for " << backend->getName());
32✔
274
      conn->setReused();
32✔
275
      return conn;
32✔
276
    }
32✔
277
    DEBUGLOG("not accepting more for " << backend->getName());
×
278
  }
×
279

280
  return nullptr;
×
281
}
32✔
282

283
void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn)
284
{
15✔
285
  d_ownedConnectionsToBackend[conn->getDS()].push_front(conn);
15✔
286
}
15✔
287

288
/* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
289
IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response)
290
{
23,206✔
291
  (void)now;
23,206✔
292
  d_state = State::sendingResponse;
23,206✔
293

294
  const auto responseSize = static_cast<uint16_t>(response.d_buffer.size());
23,206✔
295
  const std::array<uint8_t, 2> sizeBytes{static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
23,206✔
296
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
297
     that could occur if we had to deal with the size during the processing,
298
     especially alignment issues */
299
  response.d_buffer.insert(response.d_buffer.begin(), sizeBytes.begin(), sizeBytes.end());
23,206✔
300
  d_currentPos = 0;
23,206✔
301
  d_currentResponse = std::move(response);
23,206✔
302

303
  try {
23,206✔
304
    auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
23,206✔
305
    if (iostate == IOState::Done) {
23,206✔
306
      DEBUGLOG("response sent from " << __PRETTY_FUNCTION__);
23,186✔
307
      handleResponseSent(d_currentResponse, d_currentResponse.d_buffer.size());
23,186✔
308
      return iostate;
23,186✔
309
    }
23,186✔
310
    d_lastIOBlocked = true;
20✔
311
    DEBUGLOG("partial write");
20✔
312
    return iostate;
20✔
313
  }
23,206✔
314
  catch (const std::exception& e) {
23,206✔
315
    vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
4!
316
    DEBUGLOG("Closing TCP client connection: " << e.what());
4✔
317
    ++d_ci.cs->tcpDiedSendingResponse;
4✔
318

319
    terminateClientConnection();
4✔
320

321
    return IOState::Done;
4✔
322
  }
4✔
323
}
23,206✔
324

325
void IncomingTCPConnectionState::terminateClientConnection()
326
{
2,135✔
327
  DEBUGLOG("terminating client connection");
2,135✔
328
  d_queuedResponses.clear();
2,135✔
329
  /* we have already released idle connections that could be reused,
330
     we don't care about the ones still waiting for responses */
331
  for (auto& backend : d_ownedConnectionsToBackend) {
2,135✔
332
    for (auto& conn : backend.second) {
14✔
333
      conn->release(true);
14✔
334
    }
14✔
335
  }
14✔
336
  d_ownedConnectionsToBackend.clear();
2,135✔
337

338
  /* meaning we will no longer be 'active' when the backend
339
     response or timeout comes in */
340
  d_ioState.reset();
2,135✔
341

342
  /* if we do have remaining async descriptors associated with this TLS
343
     connection, we need to defer the destruction of the TLS object until
344
     the engine has reported back, otherwise we have a use-after-free.. */
345
  auto afds = d_handler.getAsyncFDs();
2,135✔
346
  if (afds.empty()) {
2,135✔
347
    d_handler.close();
2,134✔
348
  }
2,134✔
349
  else {
1✔
350
    /* we might already be waiting, but we might also not because sometimes we have already been
351
       notified via the descriptor, not received Async again, but the async job still exists.. */
352
    auto state = shared_from_this();
1✔
353
    for (const auto desc : afds) {
1!
354
      try {
×
355
        state->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, state);
×
356
      }
×
357
      catch (...) {
×
358
      }
×
359
    }
×
360
  }
1✔
361
}
2,135✔
362

363
void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response, bool fromBackend)
364
{
23,341✔
365
  // queue response
366
  state->d_queuedResponses.emplace_back(std::move(response));
23,341✔
367
  DEBUGLOG("queueing response, state is " << (int)state->d_state << ", queue size is now " << state->d_queuedResponses.size());
23,341✔
368

369
  // when the response comes from a backend, there is a real possibility that we are currently
370
  // idle, and thus not trying to send the response right away would make our ref count go to 0.
371
  // Even if we are waiting for a query, we will not wake up before the new query arrives or a
372
  // timeout occurs
373
  if (state->d_state == State::idle || state->d_state == State::waitingForQuery) {
23,341✔
374
    auto iostate = sendQueuedResponses(state, now);
23,087✔
375

376
    if (iostate == IOState::Done && state->active()) {
23,087✔
377
      if (state->canAcceptNewQueries(now)) {
23,069✔
378
        state->resetForNewQuery();
22,819✔
379
        state->d_state = State::waitingForQuery;
22,819✔
380
        iostate = IOState::NeedRead;
22,819✔
381
      }
22,819✔
382
      else {
250✔
383
        state->d_state = State::idle;
250✔
384
      }
250✔
385
    }
23,069✔
386

387
    // for the same reason we need to update the state right away, nobody will do that for us
388
    if (state->active()) {
23,087✔
389
      state->updateIO(iostate, now);
23,082✔
390
      // if we have not finished reading every available byte, we _need_ to do an actual read
391
      // attempt before waiting for the socket to become readable again, because if there is
392
      // buffered data available the socket might never become readable again.
393
      // This is true as soon as we deal with TLS because TLS records are processed one by
394
      // one and might not match what we see at the application layer, so data might already
395
      // be available in the TLS library's buffers. This is especially true when OpenSSL's
396
      // read-ahead mode is enabled because then it buffers even more than one SSL record
397
      // for performance reasons.
398
      if (fromBackend && !state->d_lastIOBlocked) {
23,082✔
399
        state->handleIO();
2,491✔
400
      }
2,491✔
401
    }
23,082✔
402
  }
23,087✔
403
}
23,341✔
404

405
void IncomingTCPConnectionState::handleAsyncReady([[maybe_unused]] int desc, FDMultiplexer::funcparam_t& param)
406
{
×
407
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
×
408

409
  /* If we are here, the async jobs for this SSL* are finished
410
     so we should be able to remove all FDs */
411
  auto afds = state->d_handler.getAsyncFDs();
×
412
  for (const auto afd : afds) {
×
413
    try {
×
414
      state->d_threadData.mplexer->removeReadFD(afd);
×
415
    }
×
416
    catch (...) {
×
417
    }
×
418
  }
×
419

420
  if (state->active()) {
×
421
    /* and now we restart our own I/O state machine */
422
    state->handleIO();
×
423
  }
×
424
  else {
×
425
    /* we were only waiting for the engine to come back,
426
       to prevent a use-after-free */
427
    state->d_handler.close();
×
428
  }
×
429
}
×
430

431
void IncomingTCPConnectionState::updateIOForAsync(std::shared_ptr<IncomingTCPConnectionState>& conn)
432
{
×
433
  auto fds = conn->d_handler.getAsyncFDs();
×
434
  for (const auto desc : fds) {
×
435
    conn->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, conn);
×
436
  }
×
437
  conn->d_ioState->update(IOState::Done, handleIOCallback, conn);
×
438
}
×
439

440
/* Updates the I/O state of the connection. The Async state is delegated to
   updateIOForAsync(); any other state is applied with the client write timeout
   when waiting to write, and the client read timeout otherwise. */
void IncomingTCPConnectionState::updateIO(IOState newState, const struct timeval& now)
{
  auto self = shared_from_this();

  if (newState != IOState::Async) {
    d_ioState->update(newState, handleIOCallback, self, newState == IOState::NeedWrite ? getClientWriteTTD(now) : getClientReadTTD(now));
    return;
  }

  updateIOForAsync(self);
}
450

451
/* called from the backend code when a new response has been received */
452
void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
453
{
2,647✔
454
  if (std::this_thread::get_id() != d_creatorThreadID) {
2,647✔
455
    handleCrossProtocolResponse(now, std::move(response));
126✔
456
    return;
126✔
457
  }
126✔
458

459
  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
2,521✔
460

461
  if (!response.isAsync() && response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
2,521!
462
    // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway
463
    if (!response.d_connection->willBeReusable(true)) {
40!
464
      // if it can't be reused even by us, well
465
      const auto connIt = state->d_ownedConnectionsToBackend.find(response.d_connection->getDS());
×
466
      if (connIt != state->d_ownedConnectionsToBackend.end()) {
×
467
        auto& list = connIt->second;
×
468

469
        for (auto it = list.begin(); it != list.end(); ++it) {
×
470
          if (*it == response.d_connection) {
×
471
            try {
×
472
              response.d_connection->release(true);
×
473
            }
×
474
            catch (const std::exception& e) {
×
475
              vinfolog("Error releasing connection: %s", e.what());
×
476
            }
×
477
            list.erase(it);
×
478
            break;
×
479
          }
×
480
        }
×
481
      }
×
482
    }
×
483
  }
40✔
484

485
  if (response.d_buffer.size() < sizeof(dnsheader)) {
2,521✔
486
    state->terminateClientConnection();
2✔
487
    return;
2✔
488
  }
2✔
489

490
  if (!response.isAsync()) {
2,519✔
491
    try {
2,435✔
492
      auto& ids = response.d_idstate;
2,435✔
493
      std::shared_ptr<DownstreamState> backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
2,435!
494
      if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend, dnsdist::configuration::getCurrentRuntimeConfiguration().d_allowEmptyResponse)) {
2,435!
495
        state->terminateClientConnection();
3✔
496
        return;
3✔
497
      }
3✔
498

499
      if (backend != nullptr) {
2,432!
500
        ++backend->responses;
2,432✔
501
      }
2,432✔
502

503
      DNSResponse dnsResponse(ids, response.d_buffer, backend);
2,432✔
504
      dnsResponse.d_incomingTCPState = state;
2,432✔
505

506
      memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH));
2,432✔
507

508
      if (!processResponse(response.d_buffer, dnsResponse, false)) {
2,432✔
509
        state->terminateClientConnection();
6✔
510
        return;
6✔
511
      }
6✔
512

513
      if (dnsResponse.isAsynchronous()) {
2,426✔
514
        /* we are done for now */
515
        return;
79✔
516
      }
79✔
517
    }
2,426✔
518
    catch (const std::exception& e) {
2,435✔
519
      vinfolog("Unexpected exception while handling response from backend: %s", e.what());
4!
520
      state->terminateClientConnection();
4✔
521
      return;
4✔
522
    }
4✔
523
  }
2,435✔
524

525
  ++dnsdist::metrics::g_stats.responses;
2,427✔
526
  ++state->d_ci.cs->responses;
2,427✔
527

528
  queueResponse(state, now, std::move(response), true);
2,427✔
529
}
2,427✔
530

531
struct TCPCrossProtocolResponse
532
{
533
  TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now) :
534
    d_response(std::move(response)), d_state(state), d_now(now)
535
  {
271✔
536
  }
271✔
537
  TCPCrossProtocolResponse(const TCPCrossProtocolResponse&) = delete;
538
  TCPCrossProtocolResponse& operator=(const TCPCrossProtocolResponse&) = delete;
539
  TCPCrossProtocolResponse(TCPCrossProtocolResponse&&) = delete;
540
  TCPCrossProtocolResponse& operator=(TCPCrossProtocolResponse&&) = delete;
541
  ~TCPCrossProtocolResponse() = default;
271✔
542

543
  TCPResponse d_response;
544
  std::shared_ptr<IncomingTCPConnectionState> d_state;
545
  struct timeval d_now;
546
};
547

548
class TCPCrossProtocolQuery : public CrossProtocolQuery
549
{
550
public:
551
  TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> backend, std::shared_ptr<IncomingTCPConnectionState> sender) :
552
    CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), backend), d_sender(std::move(sender))
553
  {
218✔
554
  }
218✔
555
  TCPCrossProtocolQuery(const TCPCrossProtocolQuery&) = delete;
556
  TCPCrossProtocolQuery& operator=(const TCPCrossProtocolQuery&) = delete;
557
  TCPCrossProtocolQuery(TCPCrossProtocolQuery&&) = delete;
558
  TCPCrossProtocolQuery& operator=(TCPCrossProtocolQuery&&) = delete;
559
  ~TCPCrossProtocolQuery() override = default;
218✔
560

561
  std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
562
  {
196✔
563
    return d_sender;
196✔
564
  }
196✔
565

566
  DNSQuestion getDQ() override
567
  {
168✔
568
    auto& ids = query.d_idstate;
168✔
569
    DNSQuestion dnsQuestion(ids, query.d_buffer);
168✔
570
    dnsQuestion.d_incomingTCPState = d_sender;
168✔
571
    return dnsQuestion;
168✔
572
  }
168✔
573

574
  DNSResponse getDR() override
575
  {
72✔
576
    auto& ids = query.d_idstate;
72✔
577
    DNSResponse dnsResponse(ids, query.d_buffer, downstream);
72✔
578
    dnsResponse.d_incomingTCPState = d_sender;
72✔
579
    return dnsResponse;
72✔
580
  }
72✔
581

582
private:
583
  std::shared_ptr<IncomingTCPConnectionState> d_sender;
584
};
585

586
/* Wraps the query and its internal state into a TCP cross-protocol query
   targeting the given backend, with this connection state as the sender. */
std::unique_ptr<CrossProtocolQuery> IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& backend)
{
  auto sender = shared_from_this();
  return std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(state), backend, std::move(sender));
}
590

591
/* Builds a TCP cross-protocol query (without a backend) out of a DNSQuestion,
   taking over its buffer and internal state. The ID found in the DNS header is
   saved into the internal state first.
   Throws std::runtime_error when the question has no incoming TCP state. */
std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion)
{
  auto tcpState = dnsQuestion.getIncomingTCPState();
  if (!tcpState) {
    throw std::runtime_error("Trying to create a TCP cross protocol query without a valid TCP state");
  }

  /* must be done before the internal state is moved away below */
  dnsQuestion.ids.origID = dnsQuestion.getHeader()->id;

  auto& payload = dnsQuestion.getMutableData();
  return std::make_unique<TCPCrossProtocolQuery>(std::move(payload), std::move(dnsQuestion.ids), nullptr, std::move(tcpState));
}
601

602
void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
603
{
271✔
604
  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
271✔
605
  try {
271✔
606
    auto ptr = std::make_unique<TCPCrossProtocolResponse>(std::move(response), state, now);
271✔
607
    if (!state->d_threadData.crossProtocolResponseSender.send(std::move(ptr))) {
271!
608
      ++dnsdist::metrics::g_stats.tcpCrossProtocolResponsePipeFull;
×
609
      vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
×
610
    }
×
611
  }
271✔
612
  catch (const std::exception& e) {
271✔
613
    vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
×
614
  }
×
615
}
271✔
616

617
IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::handleQuery(PacketBuffer&& queryIn, const struct timeval& now, std::optional<int32_t> streamID)
618
{
23,054✔
619
  auto query = std::move(queryIn);
23,054✔
620
  if (query.size() < sizeof(dnsheader)) {
23,054✔
621
    ++dnsdist::metrics::g_stats.nonCompliantQueries;
4✔
622
    ++d_ci.cs->nonCompliantQueries;
4✔
623
    return QueryProcessingResult::TooSmall;
4✔
624
  }
4✔
625

626
  ++d_queriesCount;
23,050✔
627
  ++d_ci.cs->queries;
23,050✔
628
  ++dnsdist::metrics::g_stats.queries;
23,050✔
629

630
  if (d_handler.isTLS()) {
23,050✔
631
    auto tlsVersion = d_handler.getTLSVersion();
21,033✔
632
    switch (tlsVersion) {
21,033✔
633
    case LibsslTLSVersion::TLS10:
×
634
      ++d_ci.cs->tls10queries;
×
635
      break;
×
636
    case LibsslTLSVersion::TLS11:
×
637
      ++d_ci.cs->tls11queries;
×
638
      break;
×
639
    case LibsslTLSVersion::TLS12:
12✔
640
      ++d_ci.cs->tls12queries;
12✔
641
      break;
12✔
642
    case LibsslTLSVersion::TLS13:
21,021✔
643
      ++d_ci.cs->tls13queries;
21,021✔
644
      break;
21,021✔
645
    default:
×
646
      ++d_ci.cs->tlsUnknownqueries;
×
647
    }
21,033✔
648
  }
21,033✔
649

650
  auto state = shared_from_this();
23,050✔
651
  InternalQueryState ids;
23,050✔
652
  ids.origDest = d_proxiedDestination;
23,050✔
653
  ids.origRemote = d_proxiedRemote;
23,050✔
654
  ids.cs = d_ci.cs;
23,050✔
655
  ids.queryRealTime.start();
23,050✔
656
  if (streamID) {
23,050✔
657
    ids.d_streamID = *streamID;
152✔
658
  }
152✔
659

660
  auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
23,050✔
661
  if (dnsCryptResponse) {
23,050!
662
    TCPResponse response;
×
663
    d_state = State::idle;
×
664
    ++d_currentQueriesCount;
×
665
    queueResponse(state, now, std::move(response), false);
×
666
    return QueryProcessingResult::SelfAnswered;
×
667
  }
×
668

669
  {
23,050✔
670
    /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
671
    const dnsheader_aligned dnsHeader(query.data());
23,050✔
672
    if (!checkQueryHeaders(*dnsHeader, *d_ci.cs)) {
23,050✔
673
      return QueryProcessingResult::InvalidHeaders;
5✔
674
    }
5✔
675

676
    if (dnsHeader->qdcount == 0) {
23,045✔
677
      TCPResponse response;
3✔
678
      auto queryID = dnsHeader->id;
3✔
679
      dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
3✔
680
        header.rcode = RCode::NotImp;
3✔
681
        header.qr = true;
3✔
682
        return true;
3✔
683
      });
3✔
684
      response.d_idstate = std::move(ids);
3✔
685
      response.d_idstate.origID = queryID;
3✔
686
      response.d_idstate.selfGenerated = true;
3✔
687
      response.d_buffer = std::move(query);
3✔
688
      d_state = State::idle;
3✔
689
      ++d_currentQueriesCount;
3✔
690
      queueResponse(state, now, std::move(response), false);
3✔
691
      return QueryProcessingResult::SelfAnswered;
3✔
692
    }
3✔
693
  }
23,045✔
694

695
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast
696
  ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), static_cast<int>(query.size()), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
23,042✔
697
  ids.protocol = getProtocol();
23,042✔
698
  if (ids.dnsCryptQuery) {
23,042✔
699
    ids.protocol = dnsdist::Protocol::DNSCryptTCP;
15✔
700
  }
15✔
701

702
  DNSQuestion dnsQuestion(ids, query);
23,042✔
703
  dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) {
23,042✔
704
    const uint16_t* flags = getFlagsFromDNSHeader(&header);
23,040✔
705
    ids.origFlags = *flags;
23,040✔
706
    return true;
23,040✔
707
  });
23,040✔
708
  dnsQuestion.d_incomingTCPState = state;
23,042✔
709
  dnsQuestion.sni = d_handler.getServerNameIndication();
23,042✔
710

711
  if (d_proxyProtocolValues) {
23,042✔
712
    /* we need to copy them, because the next queries received on that connection will
713
       need to get the _unaltered_ values */
714
    dnsQuestion.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*d_proxyProtocolValues);
34✔
715
  }
34✔
716

717
  if (dnsQuestion.ids.qtype == QType::AXFR || dnsQuestion.ids.qtype == QType::IXFR) {
23,042✔
718
    dnsQuestion.ids.skipCache = true;
27✔
719
  }
27✔
720

721
  if (forwardViaUDPFirst()) {
23,042✔
722
    // if there was no EDNS, we add it with a large buffer size
723
    // so we can use UDP to talk to the backend.
724
    const dnsheader_aligned dnsHeader(query.data());
146✔
725
    if (dnsHeader->arcount == 0U) {
146✔
726
      if (addEDNS(query, 4096, false, 4096, 0)) {
138!
727
        dnsQuestion.ids.ednsAdded = true;
138✔
728
      }
138✔
729
    }
138✔
730
  }
146✔
731

732
  if (streamID) {
23,042✔
733
    auto unit = getDOHUnit(*streamID);
146✔
734
    if (unit) {
146!
735
      dnsQuestion.ids.du = std::move(unit);
146✔
736
    }
146✔
737
  }
146✔
738

739
  std::shared_ptr<DownstreamState> backend;
23,042✔
740
  auto result = processQuery(dnsQuestion, backend);
23,042✔
741

742
  if (result == ProcessQueryResult::Asynchronous) {
23,042✔
743
    /* we are done for now */
744
    ++d_currentQueriesCount;
108✔
745
    return QueryProcessingResult::Asynchronous;
108✔
746
  }
108✔
747

748
  if (streamID) {
22,934✔
749
    restoreDOHUnit(std::move(dnsQuestion.ids.du));
110✔
750
  }
110✔
751

752
  if (result == ProcessQueryResult::Drop) {
22,934✔
753
    return QueryProcessingResult::Dropped;
38✔
754
  }
38✔
755

756
  // the buffer might have been invalidated by now
757
  uint16_t queryID{0};
22,896✔
758
  {
22,896✔
759
    const auto dnsHeader = dnsQuestion.getHeader();
22,896✔
760
    queryID = dnsHeader->id;
22,896✔
761
  }
22,896✔
762

763
  if (result == ProcessQueryResult::SendAnswer) {
22,896✔
764
    TCPResponse response;
20,462✔
765
    {
20,462✔
766
      const auto dnsHeader = dnsQuestion.getHeader();
20,462✔
767
      memcpy(&response.d_cleartextDH, dnsHeader.get(), sizeof(response.d_cleartextDH));
20,462✔
768
    }
20,462✔
769
    response.d_idstate = std::move(ids);
20,462✔
770
    response.d_idstate.origID = queryID;
20,462✔
771
    response.d_idstate.selfGenerated = true;
20,462✔
772
    response.d_idstate.cs = d_ci.cs;
20,462✔
773
    response.d_buffer = std::move(query);
20,462✔
774

775
    d_state = State::idle;
20,462✔
776
    ++d_currentQueriesCount;
20,462✔
777
    queueResponse(state, now, std::move(response), false);
20,462✔
778
    return QueryProcessingResult::SelfAnswered;
20,462✔
779
  }
20,462✔
780

781
  if (result != ProcessQueryResult::PassToBackend || backend == nullptr) {
2,434!
782
    return QueryProcessingResult::NoBackend;
×
783
  }
×
784

785
  dnsQuestion.ids.origID = queryID;
2,434✔
786

787
  ++d_currentQueriesCount;
2,434✔
788

789
  std::string proxyProtocolPayload;
2,434✔
790
  if (backend->isDoH()) {
2,434✔
791
    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), backend->getNameWithAddr());
27✔
792

793
    /* we need to do this _before_ creating the cross protocol query because
794
       after that the buffer will have been moved */
795
    if (backend->d_config.useProxyProtocol) {
27✔
796
      proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
1✔
797
    }
1✔
798

799
    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(ids), backend, state);
27✔
800
    cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
27✔
801

802
    backend->passCrossProtocolQuery(std::move(cpq));
27✔
803
    return QueryProcessingResult::Forwarded;
27✔
804
  }
27✔
805
  if (!backend->isTCPOnly() && forwardViaUDPFirst()) {
2,407✔
806
    if (streamID) {
61!
807
      auto unit = getDOHUnit(*streamID);
61✔
808
      if (unit) {
61!
809
        dnsQuestion.ids.du = std::move(unit);
61✔
810
      }
61✔
811
    }
61✔
812
    if (assignOutgoingUDPQueryToBackend(backend, queryID, dnsQuestion, query)) {
61✔
813
      return QueryProcessingResult::Forwarded;
60✔
814
    }
60✔
815
    restoreDOHUnit(std::move(dnsQuestion.ids.du));
1✔
816
    // fallback to the normal flow
817
  }
1✔
818

819
  prependSizeToTCPQuery(query, 0);
2,347✔
820

821
  auto downstreamConnection = getDownstreamConnection(backend, dnsQuestion.proxyProtocolValues, now);
2,347✔
822

823
  if (backend->d_config.useProxyProtocol) {
2,347✔
824
    /* if we ever sent a TLV over a connection, we can never go back */
825
    if (!d_proxyProtocolPayloadHasTLV) {
47✔
826
      d_proxyProtocolPayloadHasTLV = dnsQuestion.proxyProtocolValues && !dnsQuestion.proxyProtocolValues->empty();
32!
827
    }
32✔
828

829
    proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
47✔
830
  }
47✔
831

832
  if (dnsQuestion.proxyProtocolValues) {
2,347✔
833
    downstreamConnection->setProxyProtocolValuesSent(std::move(dnsQuestion.proxyProtocolValues));
21✔
834
  }
21✔
835

836
  TCPQuery tcpquery(std::move(query), std::move(ids));
2,347✔
837
  tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
2,347✔
838

839
  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), backend->getNameWithAddr());
2,347✔
840
  std::shared_ptr<TCPQuerySender> incoming = state;
2,347✔
841
  downstreamConnection->queueQuery(incoming, std::move(tcpquery));
2,347✔
842
  return QueryProcessingResult::Forwarded;
2,347✔
843
}
2,407✔
844

845
void IncomingTCPConnectionState::handleIOCallback(int desc, FDMultiplexer::funcparam_t& param)
846
{
2,253✔
847
  auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
2,253✔
848
  if (desc != conn->d_handler.getDescriptor()) {
2,253!
849
    // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay): __PRETTY_FUNCTION__ is fine
850
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(desc) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor()));
×
851
  }
×
852

853
  conn->handleIO();
2,253✔
854
}
2,253✔
855

856
void IncomingTCPConnectionState::handleHandshakeDone(const struct timeval& now)
857
{
2,374✔
858
  if (d_handler.isTLS()) {
2,374✔
859
    if (!d_handler.hasTLSSessionBeenResumed()) {
421✔
860
      ++d_ci.cs->tlsNewSessions;
395✔
861
    }
395✔
862
    else {
26✔
863
      ++d_ci.cs->tlsResumptions;
26✔
864
    }
26✔
865
    if (d_handler.getResumedFromInactiveTicketKey()) {
421✔
866
      ++d_ci.cs->tlsInactiveTicketKey;
8✔
867
    }
8✔
868
    if (d_handler.getUnknownTicketKey()) {
421✔
869
      ++d_ci.cs->tlsUnknownTicketKey;
6✔
870
    }
6✔
871
  }
421✔
872

873
  d_handshakeDoneTime = now;
2,374✔
874
}
2,374✔
875

876
/* Read and parse the proxy protocol header sent by the client.
   Returns:
   - Done once the header has been fully received and handled, in which
     case the read position, the needed size and the buffer are reset
   - Error if the header cannot be consumed or handled
   - Reading if more data is needed (the last read did not complete) */
IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::handleProxyProtocolPayload()
{
  do {
    DEBUGLOG("reading proxy protocol header");
    auto readState = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed, false, isProxyPayloadOutsideTLS());
    if (readState != IOState::Done) {
      d_lastIOBlocked = true;
      continue;
    }

    d_buffer.resize(d_currentPos);
    ssize_t remaining = isProxyHeaderComplete(d_buffer);
    if (remaining == 0) {
      vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", d_ci.remote.toStringWithPort());
      ++dnsdist::metrics::g_stats.proxyProtocolInvalid;
      return ProxyProtocolResult::Error;
    }

    if (remaining < 0) {
      /* a negative value is the number of bytes still missing */
      d_proxyProtocolNeed += -remaining;
      d_buffer.resize(d_currentPos + d_proxyProtocolNeed);
      /* we need to keep reading, since we might have buffered data */
      continue;
    }

    /* proxy header received */
    std::vector<ProxyProtocolValue> receivedValues;
    if (!handleProxyProtocol(d_ci.remote, true, dnsdist::configuration::getCurrentRuntimeConfiguration().d_ACL, d_buffer, d_proxiedRemote, d_proxiedDestination, receivedValues)) {
      vinfolog("Error handling the Proxy Protocol received from TCP client %s", d_ci.remote.toStringWithPort());
      return ProxyProtocolResult::Error;
    }

    if (!receivedValues.empty()) {
      d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(receivedValues));
    }

    d_currentPos = 0;
    d_proxyProtocolNeed = 0;
    d_buffer.clear();
    return ProxyProtocolResult::Done;
  } while (active() && !d_lastIOBlocked);

  return ProxyProtocolResult::Reading;
}
19✔
919

920
/* Try to make progress on the handshake.
   When it completes, the next state is either reading the proxy protocol
   header (when one is expected inside the TLS layer from this client) or
   reading the size of the first query. Otherwise the IO is marked as
   blocked. The IO state reported by the handshake attempt is returned
   in both cases. */
IOState IncomingTCPConnectionState::handleHandshake(const struct timeval& now)
{
  DEBUGLOG("doing handshake");
  auto handshakeState = d_handler.tryHandshake();
  if (handshakeState != IOState::Done) {
    d_lastIOBlocked = true;
    return handshakeState;
  }

  DEBUGLOG("handshake done");
  handleHandshakeDone(now);

  const bool proxyHeaderExpected = d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && !isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote);
  if (proxyHeaderExpected) {
    d_state = State::readingProxyProtocolHeader;
    d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
    d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
  }
  else {
    d_state = State::readingQuerySize;
  }

  return handshakeState;
}
2,314✔
943

944
/* Called when a complete query has been read from the client: hand it to
   handleQuery() and close the client connection if the query was too
   small, had invalid headers, was dropped or could not get a backend.
   Returns the IO state the caller should continue with: NeedRead or
   NeedWrite when the connection is still active, no longer idle and
   waiting for the corresponding event, Done otherwise. */
IOState IncomingTCPConnectionState::handleIncomingQueryReceived(const struct timeval& now)
{
  DEBUGLOG("query received");
  d_buffer.resize(d_querySize);

  d_state = State::idle;
  auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt);
  switch (processingResult) {
  case QueryProcessingResult::TooSmall:
    /* fall-through */
  case QueryProcessingResult::InvalidHeaders:
    /* fall-through */
  case QueryProcessingResult::Dropped:
    /* fall-through */
  case QueryProcessingResult::NoBackend:
    terminateClientConnection();
    break;
  default:
    break;
  }

  /* the state might have been updated in the meantime, we don't want to override it
     in that case */
  if (!active() || d_state == State::idle) {
    return IOState::Done;
  }
  if (d_ioState->isWaitingForRead()) {
    return IOState::NeedRead;
  }
  if (d_ioState->isWaitingForWrite()) {
    return IOState::NeedWrite;
  }
  return IOState::Done;
}
22,894✔
978

979
void IncomingTCPConnectionState::handleExceptionDuringIO(const std::exception& exp)
980
{
2,037✔
981
  if (d_state == State::idle || d_state == State::waitingForQuery) {
2,037✔
982
    /* no need to increase any counters in that case, the client is simply done with us */
983
  }
1,683✔
984
  else if (d_state == State::doingHandshake || d_state == State::readingProxyProtocolHeader || d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery) {
354!
985
    ++d_ci.cs->tcpDiedReadingQuery;
354✔
986
  }
354✔
987
  else if (d_state == State::sendingResponse) {
×
988
    /* unlikely to happen here, the exception should be handled in sendResponse() */
989
    ++d_ci.cs->tcpDiedSendingResponse;
×
990
  }
×
991

992
  if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) {
2,037!
993
    DEBUGLOG("Got an exception while handling TCP query: " << exp.what());
354✔
994
    vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), exp.what());
354!
995
  }
354✔
996
  else {
1,683✔
997
    vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), exp.what());
1,683✔
998
    DEBUGLOG("Closing TCP client connection: " << exp.what());
1,683✔
999
  }
1,683✔
1000
  /* remove this FD from the IO multiplexer */
1001
  terminateClientConnection();
2,037✔
1002
}
2,037✔
1003

1004
bool IncomingTCPConnectionState::readIncomingQuery(const timeval& now, IOState& iostate)
1005
{
27,051✔
1006
  if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize)) {
27,051!
1007
    DEBUGLOG("reading query size");
26,904✔
1008
    d_buffer.resize(sizeof(uint16_t));
26,904✔
1009
    iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t));
26,904✔
1010
    if (d_currentPos > 0) {
26,904✔
1011
      /* if we got at least one byte, we can't go around sending responses */
1012
      d_state = State::readingQuerySize;
22,909✔
1013
    }
22,909✔
1014

1015
    if (iostate == IOState::Done) {
26,904✔
1016
      DEBUGLOG("query size received");
22,905✔
1017
      d_state = State::readingQuery;
22,905✔
1018
      d_querySizeReadTime = now;
22,905✔
1019
      if (d_queriesCount == 0) {
22,905✔
1020
        d_firstQuerySizeReadTime = now;
1,820✔
1021
      }
1,820✔
1022
      d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1);
22,905✔
1023
      if (d_querySize < sizeof(dnsheader)) {
22,905✔
1024
        /* go away */
1025
        terminateClientConnection();
2✔
1026
        return true;
2✔
1027
      }
2✔
1028

1029
      d_buffer.resize(d_querySize);
22,903✔
1030
      d_currentPos = 0;
22,903✔
1031
    }
22,903✔
1032
    else {
3,999✔
1033
      d_lastIOBlocked = true;
3,999✔
1034
    }
3,999✔
1035
  }
26,904✔
1036

1037
  if (!d_lastIOBlocked && d_state == State::readingQuery) {
27,049!
1038
    DEBUGLOG("reading query");
23,050✔
1039
    iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize);
23,050✔
1040
    if (iostate == IOState::Done) {
23,050✔
1041
      iostate = handleIncomingQueryReceived(now);
22,898✔
1042
    }
22,898✔
1043
    else {
152✔
1044
      d_lastIOBlocked = true;
152✔
1045
    }
152✔
1046
  }
23,050✔
1047

1048
  return false;
27,049✔
1049
}
27,051✔
1050

1051
void IncomingTCPConnectionState::handleIO()
1052
{
6,831✔
1053
  // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
1054
  // even though the underlying socket is not ready, so we need to actually ask for the data first
1055
  IOState iostate = IOState::Done;
6,831✔
1056
  timeval now{};
6,831✔
1057
  gettimeofday(&now, nullptr);
6,831✔
1058

1059
  do {
27,462✔
1060
    iostate = IOState::Done;
27,462✔
1061
    IOStateGuard ioGuard(d_ioState);
27,462✔
1062

1063
    if (maxConnectionDurationReached(dnsdist::configuration::getCurrentRuntimeConfiguration().d_maxTCPConnectionDuration, now)) {
27,462✔
1064
      vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
1!
1065
      // will be handled by the ioGuard
1066
      // handleNewIOState(state, IOState::Done, fd, handleIOCallback);
1067
      return;
1✔
1068
    }
1✔
1069

1070
    d_lastIOBlocked = false;
27,461✔
1071

1072
    try {
27,461✔
1073
      if (d_state == State::starting) {
27,461✔
1074
        if (d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
2,177!
1075
          d_state = State::readingProxyProtocolHeader;
1✔
1076
          d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
1✔
1077
          d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
1✔
1078
        }
1✔
1079
        else {
2,176✔
1080
          d_state = State::doingHandshake;
2,176✔
1081
        }
2,176✔
1082
      }
2,177✔
1083

1084
      if (d_state == State::doingHandshake) {
27,461✔
1085
        iostate = handleHandshake(now);
2,313✔
1086
      }
2,313✔
1087

1088
      if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) {
27,461✔
1089
        auto status = handleProxyProtocolPayload();
17✔
1090
        if (status == ProxyProtocolResult::Done) {
17✔
1091
          d_buffer.resize(sizeof(uint16_t));
9✔
1092

1093
          if (isProxyPayloadOutsideTLS()) {
9✔
1094
            d_state = State::doingHandshake;
1✔
1095
            iostate = handleHandshake(now);
1✔
1096
          }
1✔
1097
          else {
8✔
1098
            d_state = State::readingQuerySize;
8✔
1099
          }
8✔
1100
        }
9✔
1101
        else if (status == ProxyProtocolResult::Error) {
8✔
1102
          iostate = IOState::Done;
3✔
1103
        }
3✔
1104
        else {
5✔
1105
          iostate = IOState::NeedRead;
5✔
1106
        }
5✔
1107
      }
17✔
1108

1109
      if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery)) {
27,461✔
1110
        if (readIncomingQuery(now, iostate)) {
27,051✔
1111
          return;
2✔
1112
        }
2✔
1113
      }
27,051✔
1114

1115
      if (!d_lastIOBlocked && d_state == State::sendingResponse) {
27,459✔
1116
        DEBUGLOG("sending response");
14✔
1117
        iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
14✔
1118
        if (iostate == IOState::Done) {
14!
1119
          DEBUGLOG("response sent from " << __PRETTY_FUNCTION__);
14✔
1120
          handleResponseSent(d_currentResponse, d_currentResponse.d_buffer.size());
14✔
1121
          d_state = State::idle;
14✔
1122
        }
14✔
1123
        else {
×
1124
          d_lastIOBlocked = true;
×
1125
        }
×
1126
      }
14✔
1127

1128
      if (active() && !d_lastIOBlocked && iostate == IOState::Done && (d_state == State::idle || d_state == State::waitingForQuery)) {
27,459!
1129
        // try sending queued responses
1130
        DEBUGLOG("send responses, if any");
2,662✔
1131
        auto state = shared_from_this();
2,662✔
1132
        iostate = sendQueuedResponses(state, now);
2,662✔
1133

1134
        if (!d_lastIOBlocked && active() && iostate == IOState::Done) {
2,662!
1135
          // if the query has been passed to a backend, or dropped, and the responses have been sent,
1136
          // we can start reading again
1137
          if (canAcceptNewQueries(now)) {
2,658✔
1138
            resetForNewQuery();
217✔
1139
            iostate = IOState::NeedRead;
217✔
1140
          }
217✔
1141
          else {
2,441✔
1142
            d_state = State::idle;
2,441✔
1143
            iostate = IOState::Done;
2,441✔
1144
          }
2,441✔
1145
        }
2,658✔
1146
      }
2,662✔
1147

1148
      if (d_state != State::idle && d_state != State::doingHandshake && d_state != State::readingProxyProtocolHeader && d_state != State::waitingForQuery && d_state != State::readingQuerySize && d_state != State::readingQuery && d_state != State::sendingResponse) {
27,459!
1149
        vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(d_state));
×
1150
      }
×
1151
    }
27,459✔
1152
    catch (const std::exception& exp) {
27,461✔
1153
      /* most likely an EOF because the other end closed the connection,
1154
         but it might also be a real IO error or something else.
1155
         Let's just drop the connection
1156
      */
1157
      handleExceptionDuringIO(exp);
2,037✔
1158
    }
2,037✔
1159

1160
    if (!active()) {
27,459✔
1161
      DEBUGLOG("state is no longer active");
2,105✔
1162
      return;
2,105✔
1163
    }
2,105✔
1164

1165
    auto sharedPtrToConn = shared_from_this();
25,354✔
1166
    if (iostate == IOState::Done) {
25,354✔
1167
      d_ioState->update(iostate, handleIOCallback, sharedPtrToConn);
2,444✔
1168
    }
2,444✔
1169
    else {
22,910✔
1170
      updateIO(iostate, now);
22,910✔
1171
    }
22,910✔
1172
    ioGuard.release();
25,354✔
1173
  } while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !d_lastIOBlocked);
25,354✔
1174
}
6,831✔
1175

1176
void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPResponse&& response)
1177
{
61✔
1178
  if (std::this_thread::get_id() != d_creatorThreadID) {
61✔
1179
    /* empty buffer will signal an IO error */
1180
    response.d_buffer.clear();
18✔
1181
    handleCrossProtocolResponse(now, std::move(response));
18✔
1182
    return;
18✔
1183
  }
18✔
1184

1185
  auto sharedPtrToConn = shared_from_this();
43✔
1186
  --sharedPtrToConn->d_currentQueriesCount;
43✔
1187
  sharedPtrToConn->d_hadErrors = true;
43✔
1188

1189
  if (sharedPtrToConn->d_state == State::sendingResponse) {
43✔
1190
    /* if we have responses to send, let's do that first */
1191
  }
2✔
1192
  else if (!sharedPtrToConn->d_queuedResponses.empty()) {
41!
1193
    /* stop reading and send what we have */
1194
    try {
×
1195
      auto iostate = sendQueuedResponses(sharedPtrToConn, now);
×
1196

1197
      if (sharedPtrToConn->active() && iostate != IOState::Done) {
×
1198
        // we need to update the state right away, nobody will do that for us
1199
        updateIO(iostate, now);
×
1200
      }
×
1201
    }
×
1202
    catch (const std::exception& e) {
×
1203
      vinfolog("Exception in notifyIOError: %s", e.what());
×
1204
    }
×
1205
  }
×
1206
  else {
41✔
1207
    // the backend code already tried to reconnect if it was possible
1208
    sharedPtrToConn->terminateClientConnection();
41✔
1209
  }
41✔
1210
}
43✔
1211

1212
/* Apply the XFR response rules to a response.
   Returns false if applying the rules failed, true otherwise. An
   extended DNS error is added to the response when one is set, unless
   the response has been turned into an asynchronous one. */
static bool processXFRResponse(DNSResponse& dnsResponse)
{
  const auto& ruleChains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains;
  const auto& xfrRules = dnsdist::rules::getResponseRuleChain(ruleChains, dnsdist::rules::ResponseRuleChain::XFRResponseRules);

  if (!applyRulesToResponse(xfrRules, dnsResponse)) {
    return false;
  }

  if (!dnsResponse.isAsynchronous() && dnsResponse.ids.d_extendedError) {
    const auto& extendedError = dnsResponse.ids.d_extendedError;
    dnsdist::edns::addExtendedDNSError(dnsResponse.getMutableData(), dnsResponse.getMaximumSize(), extendedError->infoCode, extendedError->extraText);
  }

  return true;
}
449✔
1231

1232
void IncomingTCPConnectionState::handleXFRResponse(const struct timeval& now, TCPResponse&& response)
1233
{
449✔
1234
  if (std::this_thread::get_id() != d_creatorThreadID) {
449!
1235
    handleCrossProtocolResponse(now, std::move(response));
×
1236
    return;
×
1237
  }
×
1238

1239
  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
449✔
1240
  auto& ids = response.d_idstate;
449✔
1241
  std::shared_ptr<DownstreamState> backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
449!
1242
  DNSResponse dnsResponse(ids, response.d_buffer, backend);
449✔
1243
  dnsResponse.d_incomingTCPState = state;
449✔
1244
  memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH));
449✔
1245

1246
  if (!processXFRResponse(dnsResponse)) {
449!
1247
    state->terminateClientConnection();
×
1248
    return;
×
1249
  }
×
1250

1251
  queueResponse(state, now, std::move(response), true);
449✔
1252
}
449✔
1253

1254
void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write)
1255
{
17✔
1256
  vinfolog("Timeout while %s TCP client %s", (write ? "writing to" : "reading from"), state->d_ci.remote.toStringWithPort());
17!
1257
  DEBUGLOG("client timeout");
17✔
1258
  DEBUGLOG("Processed " << state->d_queriesCount << " queries, current count is " << state->d_currentQueriesCount << ", " << state->d_ownedConnectionsToBackend.size() << " owned connections, " << state->d_queuedResponses.size() << " response queued");
17✔
1259

1260
  if (write || state->d_currentQueriesCount == 0) {
17✔
1261
    ++state->d_ci.cs->tcpClientTimeouts;
11✔
1262
    state->d_ioState.reset();
11✔
1263
  }
11✔
1264
  else {
6✔
1265
    DEBUGLOG("Going idle");
6✔
1266
    /* we still have some queries in flight, let's just stop reading for now */
1267
    state->d_state = State::idle;
6✔
1268
    state->d_ioState->update(IOState::Done, handleIOCallback, state);
6✔
1269
  }
6✔
1270
}
17✔
1271

1272
static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
1273
{
2,277✔
1274
  (void)pipefd;
2,277✔
1275
  auto* threadData = boost::any_cast<TCPClientThreadData*>(param);
2,277✔
1276

1277
  std::unique_ptr<ConnectionInfo> citmp{nullptr};
2,277✔
1278
  try {
2,277✔
1279
    auto tmp = threadData->queryReceiver.receive();
2,277✔
1280
    if (!tmp) {
2,277!
1281
      return;
×
1282
    }
×
1283
    citmp = std::move(*tmp);
2,277✔
1284
  }
2,277✔
1285
  catch (const std::exception& e) {
2,277✔
1286
    throw std::runtime_error("Error while reading from the TCP query channel: " + std::string(e.what()));
×
1287
  }
×
1288

1289
  g_tcpclientthreads->decrementQueuedCount();
2,277✔
1290

1291
  timeval now{};
2,277✔
1292
  gettimeofday(&now, nullptr);
2,277✔
1293

1294
  if (citmp->cs->dohFrontend) {
2,277✔
1295
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
192✔
1296
    auto state = std::make_shared<IncomingHTTP2Connection>(std::move(*citmp), *threadData, now);
192✔
1297
    state->handleIO();
192✔
1298
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
192✔
1299
  }
192✔
1300
  else {
2,085✔
1301
    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
2,085✔
1302
    state->handleIO();
2,085✔
1303
  }
2,085✔
1304
}
2,277✔
1305

1306
static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
1307
{
205✔
1308
  (void)pipefd;
205✔
1309
  auto* threadData = boost::any_cast<TCPClientThreadData*>(param);
205✔
1310

1311
  std::unique_ptr<CrossProtocolQuery> cpq{nullptr};
205✔
1312
  try {
205✔
1313
    auto tmp = threadData->crossProtocolQueryReceiver.receive();
205✔
1314
    if (!tmp) {
205!
1315
      return;
×
1316
    }
×
1317
    cpq = std::move(*tmp);
205✔
1318
  }
205✔
1319
  catch (const std::exception& e) {
205✔
1320
    throw std::runtime_error("Error while reading from the TCP cross-protocol channel: " + std::string(e.what()));
×
1321
  }
×
1322

1323
  timeval now{};
205✔
1324
  gettimeofday(&now, nullptr);
205✔
1325

1326
  std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
205✔
1327
  auto query = std::move(cpq->query);
205✔
1328
  auto downstreamServer = std::move(cpq->downstream);
205✔
1329

1330
  try {
205✔
1331
    auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
205✔
1332

1333
    prependSizeToTCPQuery(query.d_buffer, query.d_idstate.d_proxyProtocolPayloadSize);
205✔
1334

1335
    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr());
205✔
1336

1337
    downstream->queueQuery(tqs, std::move(query));
205✔
1338
  }
205✔
1339
  catch (...) {
205✔
1340
    tqs->notifyIOError(now, std::move(query));
×
1341
  }
×
1342
}
205✔
1343

1344
static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param)
1345
{
271✔
1346
  (void)pipefd;
271✔
1347
  auto* threadData = boost::any_cast<TCPClientThreadData*>(param);
271✔
1348

1349
  std::unique_ptr<TCPCrossProtocolResponse> cpr{nullptr};
271✔
1350
  try {
271✔
1351
    auto tmp = threadData->crossProtocolResponseReceiver.receive();
271✔
1352
    if (!tmp) {
271!
1353
      return;
×
1354
    }
×
1355
    cpr = std::move(*tmp);
271✔
1356
  }
271✔
1357
  catch (const std::exception& e) {
271✔
1358
    throw std::runtime_error("Error while reading from the TCP cross-protocol response: " + std::string(e.what()));
×
1359
  }
×
1360

1361
  auto& response = *cpr;
271✔
1362

1363
  try {
271✔
1364
    if (response.d_response.d_buffer.empty()) {
271✔
1365
      response.d_state->notifyIOError(response.d_now, std::move(response.d_response));
24✔
1366
    }
24✔
1367
    else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) {
247!
1368
      response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
×
1369
    }
×
1370
    else {
247✔
1371
      response.d_state->handleResponse(response.d_now, std::move(response.d_response));
247✔
1372
    }
247✔
1373
  }
271✔
1374
  catch (...) {
271✔
1375
    /* no point bubbling up from there */
1376
  }
×
1377
}
271✔
1378

1379
struct TCPAcceptorParam
1380
{
1381
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
1382
  ClientState& clientState;
1383
  ComboAddress local;
1384
  int socket{-1};
1385
};
1386

1387
static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData);
1388

1389
static void scanForTimeouts(const TCPClientThreadData& data, const timeval& now)
1390
{
6,435✔
1391
  auto expiredReadConns = data.mplexer->getTimeouts(now, false);
6,435✔
1392
  for (const auto& cbData : expiredReadConns) {
6,435!
1393
    if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
×
1394
      auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
×
1395
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1396
        vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
×
1397
        state->handleTimeout(state, false);
×
1398
      }
×
1399
    }
×
1400
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1401
    else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
×
1402
      auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second);
1403
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1404
        vinfolog("Timeout (read) from remote H2 client %s", state->d_ci.remote.toStringWithPort());
×
1405
        std::shared_ptr<IncomingTCPConnectionState> parentState = state;
1406
        state->handleTimeout(parentState, false);
1407
      }
1408
    }
1409
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1410
    else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
×
1411
      auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
×
1412
      vinfolog("Timeout (read) from remote backend %s", conn->getBackendName());
×
1413
      conn->handleTimeout(now, false);
×
1414
    }
×
1415
  }
×
1416

1417
  auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
6,435✔
1418
  for (const auto& cbData : expiredWriteConns) {
6,435!
1419
    if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
×
1420
      auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
×
1421
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1422
        vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
×
1423
        state->handleTimeout(state, true);
×
1424
      }
×
1425
    }
×
1426
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1427
    else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
×
1428
      auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second);
1429
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1430
        vinfolog("Timeout (write) from remote H2 client %s", state->d_ci.remote.toStringWithPort());
×
1431
        std::shared_ptr<IncomingTCPConnectionState> parentState = state;
1432
        state->handleTimeout(parentState, true);
1433
      }
1434
    }
1435
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1436
    else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
×
1437
      auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
×
1438
      vinfolog("Timeout (write) from remote backend %s", conn->getBackendName());
×
1439
      conn->handleTimeout(now, true);
×
1440
    }
×
1441
  }
×
1442
}
6,435✔
1443

1444
static void dumpTCPStates(const TCPClientThreadData& data)
1445
{
×
1446
  /* just to keep things clean in the output, debug only */
1447
  static std::mutex s_lock;
×
1448
  std::lock_guard<decltype(s_lock)> lck(s_lock);
×
1449
  if (g_tcpStatesDumpRequested > 0) {
×
1450
    /* no race here, we took the lock so it can only be increased in the meantime */
1451
    --g_tcpStatesDumpRequested;
×
1452
    infolog("Dumping the TCP states, as requested:");
×
1453
    data.mplexer->runForAllWatchedFDs([](bool isRead, int desc, const FDMultiplexer::funcparam_t& param, struct timeval ttd) {
×
1454
      timeval lnow{};
×
1455
      gettimeofday(&lnow, nullptr);
×
1456
      if (ttd.tv_sec > 0) {
×
1457
        infolog("- Descriptor %d is in %s state, TTD in %d", desc, (isRead ? "read" : "write"), (ttd.tv_sec - lnow.tv_sec));
×
1458
      }
×
1459
      else {
×
1460
        infolog("- Descriptor %d is in %s state, no TTD set", desc, (isRead ? "read" : "write"));
×
1461
      }
×
1462

1463
      if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
×
1464
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
×
1465
        infolog(" - %s", state->toString());
×
1466
      }
×
1467
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1468
      else if (param.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
×
1469
        auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
1470
        infolog(" - %s", state->toString());
1471
      }
1472
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1473
      else if (param.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
×
1474
        auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param);
×
1475
        infolog(" - %s", conn->toString());
×
1476
      }
×
1477
      else if (param.type() == typeid(TCPClientThreadData*)) {
×
1478
        infolog(" - Worker thread pipe");
×
1479
      }
×
1480
    });
×
1481
    infolog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount());
×
1482
  }
×
1483
}
×
1484

1485
// NOLINTNEXTLINE(performance-unnecessary-value-param): you are wrong, clang-tidy, go home
1486
static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates)
1487
{
3,131✔
1488
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
1489
     from that point on */
1490

1491
  setThreadName("dnsdist/tcpClie");
3,131✔
1492

1493
  try {
3,131✔
1494
    TCPClientThreadData data;
3,131✔
1495
    data.crossProtocolResponseSender = std::move(crossProtocolResponseSender);
3,131✔
1496
    data.queryReceiver = std::move(queryReceiver);
3,131✔
1497
    data.crossProtocolQueryReceiver = std::move(crossProtocolQueryReceiver);
3,131✔
1498
    data.crossProtocolResponseReceiver = std::move(crossProtocolResponseReceiver);
3,131✔
1499

1500
    data.mplexer->addReadFD(data.queryReceiver.getDescriptor(), handleIncomingTCPQuery, &data);
3,131✔
1501
    data.mplexer->addReadFD(data.crossProtocolQueryReceiver.getDescriptor(), handleCrossProtocolQuery, &data);
3,131✔
1502
    data.mplexer->addReadFD(data.crossProtocolResponseReceiver.getDescriptor(), handleCrossProtocolResponse, &data);
3,131✔
1503

1504
    /* only used in single acceptor mode for now */
1505
    std::vector<TCPAcceptorParam> acceptParams;
3,131✔
1506
    acceptParams.reserve(tcpAcceptStates.size());
3,131✔
1507

1508
    for (auto& state : tcpAcceptStates) {
3,131!
1509
      acceptParams.emplace_back(TCPAcceptorParam{*state, state->local, state->tcpFD});
×
1510
      for (const auto& [addr, socket] : state->d_additionalAddresses) {
×
1511
        acceptParams.emplace_back(TCPAcceptorParam{*state, addr, socket});
×
1512
      }
×
1513
    }
×
1514

1515
    auto acceptCallback = [&data](int socket, FDMultiplexer::funcparam_t& funcparam) {
3,131✔
NEW
1516
      (void)socket;
×
1517
      const auto* acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam);
×
1518
      acceptNewConnection(*acceptorParam, &data);
×
1519
    };
×
1520

1521
    for (const auto& param : acceptParams) {
3,131!
1522
      setNonBlocking(param.socket);
×
1523
      data.mplexer->addReadFD(param.socket, acceptCallback, &param);
×
1524
    }
×
1525

1526
    timeval now{};
3,131✔
1527
    gettimeofday(&now, nullptr);
3,131✔
1528
    time_t lastTimeoutScan = now.tv_sec;
3,131✔
1529

1530
    for (;;) {
24,535✔
1531
      data.mplexer->run(&now);
24,535✔
1532

1533
      try {
24,535✔
1534
        t_downstreamTCPConnectionsManager.cleanupClosedConnections(now);
24,535✔
1535

1536
        if (now.tv_sec > lastTimeoutScan) {
24,535✔
1537
          lastTimeoutScan = now.tv_sec;
6,603✔
1538
          scanForTimeouts(data, now);
6,603✔
1539

1540
          if (g_tcpStatesDumpRequested > 0) {
6,603!
1541
            dumpTCPStates(data);
×
1542
          }
×
1543
        }
6,603✔
1544
      }
24,535✔
1545
      catch (const std::exception& e) {
24,535✔
1546
        warnlog("Error in TCP worker thread: %s", e.what());
×
1547
      }
×
1548
    }
24,535✔
1549
  }
3,131✔
1550
  catch (const std::exception& e) {
3,131✔
1551
    errlog("Fatal error in TCP worker thread: %s", e.what());
×
1552
  }
×
1553
}
3,131✔
1554

1555
static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData)
1556
{
2,679✔
1557
  auto& clientState = param.clientState;
2,679✔
1558
  const bool checkACL = clientState.dohFrontend == nullptr || (!clientState.dohFrontend->d_trustForwardedForHeader && clientState.dohFrontend->d_earlyACLDrop);
2,679!
1559
  const int socket = param.socket;
2,679✔
1560
  bool tcpClientCountIncremented = false;
2,679✔
1561
  ComboAddress remote;
2,679✔
1562
  remote.sin4.sin_family = param.local.sin4.sin_family;
2,679✔
1563

1564
  tcpClientCountIncremented = false;
2,679✔
1565
  try {
2,679✔
1566
    socklen_t remlen = remote.getSocklen();
2,679✔
1567
    ConnectionInfo connInfo(&clientState);
2,679✔
1568
#ifdef HAVE_ACCEPT4
2,679✔
1569
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1570
    connInfo.fd = accept4(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
2,679✔
1571
#else
1572
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1573
    connInfo.fd = accept(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
1574
#endif
1575
    // will be decremented when the ConnectionInfo object is destroyed, no matter the reason
1576
    auto concurrentConnections = ++clientState.tcpCurrentConnections;
2,679✔
1577

1578
    if (connInfo.fd < 0) {
2,679!
1579
      throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
×
1580
    }
×
1581

1582
    if (checkACL && !dnsdist::configuration::getCurrentRuntimeConfiguration().d_ACL.match(remote)) {
2,679✔
1583
      ++dnsdist::metrics::g_stats.aclDrops;
9✔
1584
      vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
9✔
1585
      return;
9✔
1586
    }
9✔
1587

1588
    if (clientState.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > clientState.d_tcpConcurrentConnectionsLimit) {
2,670✔
1589
      vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort());
3✔
1590
      return;
3✔
1591
    }
3✔
1592

1593
    if (concurrentConnections > clientState.tcpMaxConcurrentConnections.load()) {
2,667✔
1594
      clientState.tcpMaxConcurrentConnections.store(concurrentConnections);
500✔
1595
    }
500✔
1596

1597
#ifndef HAVE_ACCEPT4
1598
    if (!setNonBlocking(connInfo.fd)) {
1599
      return;
1600
    }
1601
#endif
1602

1603
    setTCPNoDelay(connInfo.fd); // disable NAGLE
2,667✔
1604

1605
    const auto maxTCPQueuedConnections = dnsdist::configuration::getImmutableConfiguration().d_maxTCPQueuedConnections;
2,667✔
1606
    if (maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= maxTCPQueuedConnections) {
2,667!
1607
      vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
×
1608
      return;
×
1609
    }
×
1610

1611
    if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) {
2,667✔
1612
      vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
2✔
1613
      return;
2✔
1614
    }
2✔
1615
    tcpClientCountIncremented = true;
2,665✔
1616

1617
    vinfolog("Got TCP connection from %s", remote.toStringWithPort());
2,665✔
1618

1619
    connInfo.remote = remote;
2,665✔
1620

1621
    if (threadData == nullptr) {
2,665✔
1622
      if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(connInfo)))) {
2,277!
1623
        if (tcpClientCountIncremented) {
×
1624
          dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
×
1625
        }
×
1626
      }
×
1627
    }
2,277✔
1628
    else {
388✔
1629
      timeval now{};
388✔
1630
      gettimeofday(&now, nullptr);
388✔
1631

1632
      if (connInfo.cs->dohFrontend) {
388!
1633
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1634
        auto state = std::make_shared<IncomingHTTP2Connection>(std::move(connInfo), *threadData, now);
1635
        state->handleIO();
1636
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1637
      }
×
1638
      else {
388✔
1639
        auto state = std::make_shared<IncomingTCPConnectionState>(std::move(connInfo), *threadData, now);
388✔
1640
        state->handleIO();
388✔
1641
      }
388✔
1642
    }
388✔
1643
  }
2,665✔
1644
  catch (const std::exception& e) {
2,679✔
1645
    errlog("While reading a TCP question: %s", e.what());
×
1646
    if (tcpClientCountIncremented) {
×
1647
      dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
×
1648
    }
×
1649
  }
×
1650
  catch (...) {
2,679✔
1651
  }
×
1652
}
2,679✔
1653

1654
/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
1655
   they will hand off to worker threads & spawn more of them if required
1656
*/
1657
#ifndef USE_SINGLE_ACCEPTOR_THREAD
1658
void tcpAcceptorThread(const std::vector<ClientState*>& states)
1659
{
388✔
1660
  setThreadName("dnsdist/tcpAcce");
388✔
1661

1662
  std::vector<TCPAcceptorParam> params;
388✔
1663
  params.reserve(states.size());
388✔
1664

1665
  for (const auto& state : states) {
388✔
1666
    params.emplace_back(TCPAcceptorParam{*state, state->local, state->tcpFD});
388✔
1667
    for (const auto& [addr, socket] : state->d_additionalAddresses) {
388!
1668
      params.emplace_back(TCPAcceptorParam{*state, addr, socket});
×
1669
    }
×
1670
  }
388✔
1671

1672
  if (params.size() == 1) {
388!
1673
    while (true) {
3,067✔
1674
      acceptNewConnection(params.at(0), nullptr);
2,679✔
1675
    }
2,679✔
1676
  }
388✔
1677
  else {
×
1678
    auto acceptCallback = [](int socket, FDMultiplexer::funcparam_t& funcparam) {
×
NEW
1679
      (void)socket;
×
1680
      const auto* acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam);
×
1681
      acceptNewConnection(*acceptorParam, nullptr);
×
1682
    };
×
1683

1684
    auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(params.size()));
×
1685
    for (const auto& param : params) {
×
1686
      mplexer->addReadFD(param.socket, acceptCallback, &param);
×
1687
    }
×
1688

1689
    timeval now{};
×
1690
    while (true) {
×
1691
      mplexer->run(&now, -1);
×
1692
    }
×
1693
  }
×
1694
}
388✔
1695
#endif
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc