
PowerDNS / pdns / 11482336920

23 Oct 2024 02:42PM UTC coverage: 64.627% (+0.8%) from 63.846%

Bump check-spelling/check-spelling from 0.0.22 to 0.0.23

Bumps [check-spelling/check-spelling](https://github.com/check-spelling/check-spelling) from 0.0.22 to 0.0.23.
- [Release notes](https://github.com/check-spelling/check-spelling/releases)
- [Changelog](https://github.com/check-spelling/check-spelling/blob/main/gh-release-downloader)
- [Commits](https://github.com/check-spelling/check-spelling/compare/v0.0.22...v0.0.23)

---
updated-dependencies:
- dependency-name: check-spelling/check-spelling
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Pull Request #14793: Bump check-spelling/check-spelling from 0.0.22 to 0.0.23

37129 of 88222 branches covered (42.09%)

Branch coverage included in aggregate %.

124752 of 162264 relevant lines covered (76.88%)

4563787.68 hits per line

Source File: /pdns/dnsdistdist/dnsdist-tcp.cc (75.79% covered)
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22

23
#include <thread>
24
#include <netinet/tcp.h>
25
#include <queue>
26
#include <boost/format.hpp>
27

28
#include "dnsdist.hh"
29
#include "dnsdist-concurrent-connections.hh"
30
#include "dnsdist-dnsparser.hh"
31
#include "dnsdist-ecs.hh"
32
#include "dnsdist-edns.hh"
33
#include "dnsdist-nghttp2-in.hh"
34
#include "dnsdist-proxy-protocol.hh"
35
#include "dnsdist-rings.hh"
36
#include "dnsdist-tcp.hh"
37
#include "dnsdist-tcp-downstream.hh"
38
#include "dnsdist-downstream-connection.hh"
39
#include "dnsdist-tcp-upstream.hh"
40
#include "dnsparser.hh"
41
#include "dolog.hh"
42
#include "gettime.hh"
43
#include "lock.hh"
44
#include "sstuff.hh"
45
#include "tcpiohandler.hh"
46
#include "tcpiohandler-mplexer.hh"
47
#include "threadname.hh"
48

49
/* TCP: the grand design.
50
   We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
51
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
52
   we will not go there.
53

54
   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
55
   This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
56
   to guarantee performance.
57

58
   So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
59
   So whenever an answer comes in, we know where it needs to go.
60

61
   Let's start naively.
62
*/
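/* Illustrative sketch (not part of dnsdist-tcp.cc): the 65k cap mentioned above comes from
   the TCP transport described in RFC 1035 section 4.2.2, where every DNS message is preceded
   by a two-byte, big-endian length field, limiting a single message to 65535 bytes. A minimal,
   standalone framing helper could look like this; the name frameDNSMessageForTCP is
   hypothetical and only used for illustration. */
#include <cstdint>
#include <stdexcept>
#include <vector>

static std::vector<uint8_t> frameDNSMessageForTCP(const std::vector<uint8_t>& message)
{
  if (message.size() > 65535) {
    throw std::runtime_error("DNS message too large for the two-byte TCP length prefix");
  }
  std::vector<uint8_t> framed;
  framed.reserve(message.size() + 2);
  // big-endian (network order) 16-bit length prefix
  framed.push_back(static_cast<uint8_t>(message.size() / 256));
  framed.push_back(static_cast<uint8_t>(message.size() % 256));
  framed.insert(framed.end(), message.begin(), message.end());
  return framed;
}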
63

64
std::atomic<uint64_t> g_tcpStatesDumpRequested{0};
65

66
LockGuarded<std::map<ComboAddress, size_t, ComboAddress::addressOnlyLessThan>> dnsdist::IncomingConcurrentTCPConnectionsManager::s_tcpClientsConcurrentConnectionsCount;
67

68
IncomingTCPConnectionState::~IncomingTCPConnectionState()
69
{
2,063✔
70
  dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(d_ci.remote);
2,063✔
71

72
  if (d_ci.cs != nullptr) {
2,063!
73
    timeval now{};
2,063✔
74
    gettimeofday(&now, nullptr);
2,063✔
75

76
    auto diff = now - d_connectionStartTime;
2,063✔
77
    d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000 + diff.tv_usec / 1000);
2,063✔
78
  }
2,063✔
79

80
  // would have been done when the object is destroyed anyway,
81
  // but that way we make sure it's done before the ConnectionInfo is destroyed,
82
  // closing the descriptor, instead of relying on the declaration order of the objects in the class
83
  d_handler.close();
2,063✔
84
}
2,063✔
85

86
dnsdist::Protocol IncomingTCPConnectionState::getProtocol() const
87
{
23,270✔
88
  if (d_ci.cs->dohFrontend) {
23,270✔
89
    return dnsdist::Protocol::DoH;
161✔
90
  }
161✔
91
  if (d_handler.isTLS()) {
23,109✔
92
    return dnsdist::Protocol::DoT;
20,790✔
93
  }
20,790✔
94
  return dnsdist::Protocol::DoTCP;
2,319✔
95
}
23,109✔
96

97
size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
98
{
182✔
99
  return t_downstreamTCPConnectionsManager.clear();
182✔
100
}
182✔
101

102
std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now)
103
{
2,019✔
104
  auto downstream = getOwnedDownstreamConnection(backend, tlvs);
2,019✔
105

106
  if (!downstream) {
2,019✔
107
    /* we don't have a connection to this backend owned yet, let's get one (it might not be a fresh one, though) */
108
    downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, backend, now, std::string());
1,987✔
109
    if (backend->d_config.useProxyProtocol) {
1,987✔
110
      registerOwnedDownstreamConnection(downstream);
15✔
111
    }
15✔
112
  }
1,987✔
113

114
  return downstream;
2,019✔
115
}
2,019✔
116

117
static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates);
118

119
TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientState*> tcpAcceptStates) :
120
  d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
121
{
318✔
122
  for (size_t idx = 0; idx < maxThreads; idx++) {
3,307✔
123
    addTCPClientThread(tcpAcceptStates);
2,989✔
124
  }
2,989✔
125
}
318✔
126

127
void TCPClientCollection::addTCPClientThread(std::vector<ClientState*>& tcpAcceptStates)
128
{
2,989✔
129
  try {
2,989✔
130
    const auto internalPipeBufferSize = dnsdist::configuration::getImmutableConfiguration().d_tcpInternalPipeBufferSize;
2,989✔
131

132
    auto [queryChannelSender, queryChannelReceiver] = pdns::channel::createObjectQueue<ConnectionInfo>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize);
2,989✔
133

134
    auto [crossProtocolQueryChannelSender, crossProtocolQueryChannelReceiver] = pdns::channel::createObjectQueue<CrossProtocolQuery>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize);
2,989✔
135

136
    auto [crossProtocolResponseChannelSender, crossProtocolResponseChannelReceiver] = pdns::channel::createObjectQueue<TCPCrossProtocolResponse>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize);
2,989✔
137

138
    vinfolog("Adding TCP Client thread");
2,989✔
139

140
    if (d_numthreads >= d_tcpclientthreads.size()) {
2,989!
141
      vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
×
142
      return;
×
143
    }
×
144

145
    TCPWorkerThread worker(std::move(queryChannelSender), std::move(crossProtocolQueryChannelSender));
2,989✔
146

147
    try {
2,989✔
148
      std::thread clientThread(tcpClientThread, std::move(queryChannelReceiver), std::move(crossProtocolQueryChannelReceiver), std::move(crossProtocolResponseChannelReceiver), std::move(crossProtocolResponseChannelSender), tcpAcceptStates);
2,989✔
149
      clientThread.detach();
2,989✔
150
    }
2,989✔
151
    catch (const std::runtime_error& e) {
2,989✔
152
      errlog("Error creating a TCP thread: %s", e.what());
×
153
      return;
×
154
    }
×
155

156
    d_tcpclientthreads.at(d_numthreads) = std::move(worker);
2,989✔
157
    ++d_numthreads;
2,989✔
158
  }
2,989✔
159
  catch (const std::exception& e) {
2,989✔
160
    errlog("Error creating TCP worker: %s", e.what());
×
161
  }
×
162
}
2,989✔
163

164
std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
165

166
static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
167
{
25,065✔
168
  IOState result = IOState::Done;
25,065✔
169

170
  while (state->active() && !state->d_queuedResponses.empty()) {
48,023✔
171
    DEBUGLOG("queue size is " << state->d_queuedResponses.size() << ", sending the next one");
22,975✔
172
    TCPResponse resp = std::move(state->d_queuedResponses.front());
22,975✔
173
    state->d_queuedResponses.pop_front();
22,975✔
174
    state->d_state = IncomingTCPConnectionState::State::idle;
22,975✔
175
    result = state->sendResponse(now, std::move(resp));
22,975✔
176
    if (result != IOState::Done) {
22,975✔
177
      return result;
17✔
178
    }
17✔
179
  }
22,975✔
180

181
  state->d_state = IncomingTCPConnectionState::State::idle;
25,048✔
182
  return IOState::Done;
25,048✔
183
}
25,065✔
184

185
void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse, size_t sentBytes)
186
{
22,969✔
187
  if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) {
22,969✔
188
    return;
451✔
189
  }
451✔
190

191
  --d_currentQueriesCount;
22,518✔
192

193
  const auto& backend = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds;
22,518✔
194
  if (!currentResponse.d_idstate.selfGenerated && backend) {
22,518!
195
    const auto& ids = currentResponse.d_idstate;
2,069✔
196
    double udiff = ids.queryRealTime.udiff();
2,069✔
197
    vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", backend->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), sentBytes, udiff);
2,069✔
198

199
    auto backendProtocol = backend->getProtocol();
2,069✔
200
    if (backendProtocol == dnsdist::Protocol::DoUDP && !currentResponse.d_idstate.forwardedOverUDP) {
2,069✔
201
      backendProtocol = dnsdist::Protocol::DoTCP;
1,672✔
202
    }
1,672✔
203
    ::handleResponseSent(ids, udiff, d_ci.remote, backend->d_config.remote, static_cast<unsigned int>(sentBytes), currentResponse.d_cleartextDH, backendProtocol, true);
2,069✔
204
  }
2,069✔
205
  else {
20,449✔
206
    const auto& ids = currentResponse.d_idstate;
20,449✔
207
    ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
20,449✔
208
  }
20,449✔
209

210
  currentResponse.d_buffer.clear();
22,518✔
211
  currentResponse.d_connection.reset();
22,518✔
212
}
22,518✔
213

214
static void prependSizeToTCPQuery(PacketBuffer& buffer, size_t proxyProtocolPayloadSize)
215
{
2,205✔
216
  if (buffer.size() <= proxyProtocolPayloadSize) {
2,205!
217
    throw std::runtime_error("The payload size is smaller or equal to the buffer size");
×
218
  }
×
219

220
  uint16_t queryLen = proxyProtocolPayloadSize > 0 ? (buffer.size() - proxyProtocolPayloadSize) : buffer.size();
2,205✔
221
  const std::array<uint8_t, 2> sizeBytes{static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256)};
2,205✔
222
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
223
     that could occur if we had to deal with the size during the processing,
224
     especially alignment issues */
225
  buffer.insert(buffer.begin() + static_cast<PacketBuffer::iterator::difference_type>(proxyProtocolPayloadSize), sizeBytes.begin(), sizeBytes.end());
2,205✔
226
}
2,205✔
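/* Illustrative usage sketch (not part of dnsdist-tcp.cc): when a proxy protocol payload of
   proxyProtocolPayloadSize bytes sits at the front of the buffer, the two length bytes are
   inserted right after it and cover only the DNS message that follows, so the backend sees
   [proxy payload][len hi][len lo][DNS message]. Assuming PacketBuffer behaves like
   std::vector<uint8_t>, a check of the resulting layout could look like this; the helper
   name checkPrependedSize is hypothetical. */
#include <cassert>
#include <cstddef>

static void checkPrependedSize(const PacketBuffer& buffer, size_t proxyProtocolPayloadSize)
{
  // length of the DNS message alone, i.e. everything after the proxy payload and the prefix
  const size_t dnsLen = buffer.size() - proxyProtocolPayloadSize - 2;
  assert(buffer.at(proxyProtocolPayloadSize) == dnsLen / 256);
  assert(buffer.at(proxyProtocolPayloadSize + 1) == dnsLen % 256);
}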
227

228
bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now)
229
{
25,043✔
230
  if (d_hadErrors) {
25,043✔
231
    DEBUGLOG("not accepting new queries because we encountered some error during the processing already");
2✔
232
    return false;
2✔
233
  }
2✔
234

235
  // for DoH, this is already handled by the underlying library
236
  if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
25,041✔
237
    DEBUGLOG("not accepting new queries because we already have " << d_currentQueriesCount << " out of " << d_ci.cs->d_maxInFlightQueriesPerConn);
2,156✔
238
    return false;
2,156✔
239
  }
2,156✔
240

241
  const auto& currentConfig = dnsdist::configuration::getCurrentRuntimeConfiguration();
22,885✔
242
  if (currentConfig.d_maxTCPQueriesPerConn != 0 && d_queriesCount > currentConfig.d_maxTCPQueriesPerConn) {
22,885✔
243
    vinfolog("not accepting new queries from %s because it reached the maximum number of queries per conn (%d / %d)", d_ci.remote.toStringWithPort(), d_queriesCount, currentConfig.d_maxTCPQueriesPerConn);
208✔
244
    return false;
208✔
245
  }
208✔
246

247
  if (maxConnectionDurationReached(currentConfig.d_maxTCPConnectionDuration, now)) {
22,677!
248
    vinfolog("not accepting new queries from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
×
249
    return false;
×
250
  }
×
251

252
  return true;
22,677✔
253
}
22,677✔
254

255
void IncomingTCPConnectionState::resetForNewQuery()
256
{
22,677✔
257
  d_buffer.clear();
22,677✔
258
  d_currentPos = 0;
22,677✔
259
  d_querySize = 0;
22,677✔
260
  d_state = State::waitingForQuery;
22,677✔
261
}
22,677✔
262

263
std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs)
264
{
2,019✔
265
  auto connIt = d_ownedConnectionsToBackend.find(backend);
2,019✔
266
  if (connIt == d_ownedConnectionsToBackend.end()) {
2,019✔
267
    DEBUGLOG("no owned connection found for " << backend->getName());
1,987✔
268
    return nullptr;
1,987✔
269
  }
1,987✔
270

271
  for (auto& conn : connIt->second) {
32!
272
    if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) {
32!
273
      DEBUGLOG("Got one owned connection accepting more for " << backend->getName());
32✔
274
      conn->setReused();
32✔
275
      return conn;
32✔
276
    }
32✔
277
    DEBUGLOG("not accepting more for " << backend->getName());
×
278
  }
×
279

280
  return nullptr;
×
281
}
32✔
282

283
void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn)
284
{
15✔
285
  d_ownedConnectionsToBackend[conn->getDS()].push_front(conn);
15✔
286
}
15✔
287

288
/* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
289
IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response)
290
{
22,862✔
291
  d_state = State::sendingResponse;
22,862✔
292

293
  const auto responseSize = static_cast<uint16_t>(response.d_buffer.size());
22,862✔
294
  const std::array<uint8_t, 2> sizeBytes{static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
22,862✔
295
  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
296
     that could occur if we had to deal with the size during the processing,
297
     especially alignment issues */
298
  response.d_buffer.insert(response.d_buffer.begin(), sizeBytes.begin(), sizeBytes.end());
22,862✔
299
  d_currentPos = 0;
22,862✔
300
  d_currentResponse = std::move(response);
22,862✔
301

302
  try {
22,862✔
303
    auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
22,862✔
304
    if (iostate == IOState::Done) {
22,862✔
305
      DEBUGLOG("response sent from " << __PRETTY_FUNCTION__);
22,842✔
306
      handleResponseSent(d_currentResponse, d_currentResponse.d_buffer.size());
22,842✔
307
      return iostate;
22,842✔
308
    }
22,842✔
309
    d_lastIOBlocked = true;
20✔
310
    DEBUGLOG("partial write");
20✔
311
    return iostate;
20✔
312
  }
22,862✔
313
  catch (const std::exception& e) {
22,862✔
314
    vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
4!
315
    DEBUGLOG("Closing TCP client connection: " << e.what());
4✔
316
    ++d_ci.cs->tcpDiedSendingResponse;
4✔
317

318
    terminateClientConnection();
4✔
319

320
    return IOState::Done;
4✔
321
  }
4✔
322
}
22,862✔
323

324
void IncomingTCPConnectionState::terminateClientConnection()
325
{
1,878✔
326
  DEBUGLOG("terminating client connection");
1,878✔
327
  d_queuedResponses.clear();
1,878✔
328
  /* we have already released idle connections that could be reused,
329
     we don't care about the ones still waiting for responses */
330
  for (auto& backend : d_ownedConnectionsToBackend) {
1,878✔
331
    for (auto& conn : backend.second) {
14✔
332
      conn->release(true);
14✔
333
    }
14✔
334
  }
14✔
335
  d_ownedConnectionsToBackend.clear();
1,878✔
336

337
  /* meaning we will no longer be 'active' when the backend
338
     response or timeout comes in */
339
  d_ioState.reset();
1,878✔
340

341
  /* if we do have remaining async descriptors associated with this TLS
342
     connection, we need to defer the destruction of the TLS object until
343
     the engine has reported back, otherwise we have a use-after-free.. */
344
  auto afds = d_handler.getAsyncFDs();
1,878✔
345
  if (afds.empty()) {
1,878!
346
    d_handler.close();
1,878✔
347
  }
1,878✔
348
  else {
×
349
    /* we might already be waiting, but we might also not because sometimes we have already been
350
       notified via the descriptor, not received Async again, but the async job still exists.. */
351
    auto state = shared_from_this();
×
352
    for (const auto desc : afds) {
×
353
      try {
×
354
        state->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, state);
×
355
      }
×
356
      catch (...) {
×
357
      }
×
358
    }
×
359
  }
×
360
}
1,878✔
361

362
void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response, bool fromBackend)
363
{
22,975✔
364
  // queue response
365
  state->d_queuedResponses.emplace_back(std::move(response));
22,975✔
366
  DEBUGLOG("queueing response, state is " << (int)state->d_state << ", queue size is now " << state->d_queuedResponses.size());
22,975✔
367

368
  // when the response comes from a backend, there is a real possibility that we are currently
369
  // idle, and thus not trying to send the response right away would make our ref count go to 0.
370
  // Even if we are waiting for a query, we will not wake up before the new query arrives or a
371
  // timeout occurs
372
  if (state->d_state == State::idle || state->d_state == State::waitingForQuery) {
22,975✔
373
    auto iostate = sendQueuedResponses(state, now);
22,731✔
374

375
    if (iostate == IOState::Done && state->active()) {
22,731✔
376
      if (state->canAcceptNewQueries(now)) {
22,713✔
377
        state->resetForNewQuery();
22,463✔
378
        state->d_state = State::waitingForQuery;
22,463✔
379
        iostate = IOState::NeedRead;
22,463✔
380
      }
22,463✔
381
      else {
250✔
382
        state->d_state = State::idle;
250✔
383
      }
250✔
384
    }
22,713✔
385

386
    // for the same reason we need to update the state right away, nobody will do that for us
387
    if (state->active()) {
22,731✔
388
      state->updateIO(iostate, now);
22,726✔
389
      // if we have not finished reading every available byte, we _need_ to do an actual read
390
      // attempt before waiting for the socket to become readable again, because if there is
391
      // buffered data available the socket might never become readable again.
392
      // This is true as soon as we deal with TLS because TLS records are processed one by
393
      // one and might not match what we see at the application layer, so data might already
394
      // be available in the TLS library's buffers. This is especially true when OpenSSL's
395
      // read-ahead mode is enabled because then it buffers even more than one SSL record
396
      // for performance reasons.
397
      if (fromBackend && !state->d_lastIOBlocked) {
22,726✔
398
        state->handleIO();
2,156✔
399
      }
2,156✔
400
    }
22,726✔
401
  }
22,731✔
402
}
22,975✔
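/* Illustrative sketch (not part of dnsdist-tcp.cc): the comment above refers to data the TLS
   library has already decrypted and buffered internally. Such data will never make the
   underlying socket readable again, so it has to be read before going back to the multiplexer.
   With OpenSSL, for example, the amount of already-processed buffered data is reported by
   SSL_pending(); the helper name drainBufferedTLSData is hypothetical and only used for
   illustration. */
#include <cstdint>
#include <vector>
#include <openssl/ssl.h>

static void drainBufferedTLSData(SSL* tlsConn, std::vector<uint8_t>& out)
{
  // keep reading while OpenSSL still holds decrypted bytes for us
  while (SSL_pending(tlsConn) > 0) {
    uint8_t chunk[512];
    const int got = SSL_read(tlsConn, chunk, sizeof(chunk));
    if (got <= 0) {
      break; // an error or shutdown would be handled by the caller
    }
    out.insert(out.end(), chunk, chunk + got);
  }
}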
403

404
void IncomingTCPConnectionState::handleAsyncReady([[maybe_unused]] int desc, FDMultiplexer::funcparam_t& param)
405
{
×
406
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
×
407

408
  /* If we are here, the async jobs for this SSL* are finished
409
     so we should be able to remove all FDs */
410
  auto afds = state->d_handler.getAsyncFDs();
×
411
  for (const auto afd : afds) {
×
412
    try {
×
413
      state->d_threadData.mplexer->removeReadFD(afd);
×
414
    }
×
415
    catch (...) {
×
416
    }
×
417
  }
×
418

419
  if (state->active()) {
×
420
    /* and now we restart our own I/O state machine */
421
    state->handleIO();
×
422
  }
×
423
  else {
×
424
    /* we were only waiting for the engine to come back,
425
       to prevent a use-after-free */
426
    state->d_handler.close();
×
427
  }
×
428
}
×
429

430
void IncomingTCPConnectionState::updateIOForAsync(std::shared_ptr<IncomingTCPConnectionState>& conn)
431
{
×
432
  auto fds = conn->d_handler.getAsyncFDs();
×
433
  for (const auto desc : fds) {
×
434
    conn->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, conn);
×
435
  }
×
436
  conn->d_ioState->update(IOState::Done, handleIOCallback, conn);
×
437
}
×
438

439
void IncomingTCPConnectionState::updateIO(IOState newState, const struct timeval& now)
440
{
45,263✔
441
  auto sharedPtrToConn = shared_from_this();
45,263✔
442
  if (newState == IOState::Async) {
45,263!
443
    updateIOForAsync(sharedPtrToConn);
×
444
    return;
×
445
  }
×
446

447
  d_ioState->update(newState, handleIOCallback, sharedPtrToConn, newState == IOState::NeedWrite ? getClientWriteTTD(now) : getClientReadTTD(now));
45,263✔
448
}
45,263✔
449

450
/* called from the backend code when a new response has been received */
451
void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
452
{
2,307✔
453
  if (std::this_thread::get_id() != d_creatorThreadID) {
2,307✔
454
    handleCrossProtocolResponse(now, std::move(response));
124✔
455
    return;
124✔
456
  }
124✔
457

458
  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
2,183✔
459

460
  if (!response.isAsync() && response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
2,183!
461
    // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway
462
    if (!response.d_connection->willBeReusable(true)) {
40!
463
      // if it can't be reused even by us, well
464
      const auto connIt = state->d_ownedConnectionsToBackend.find(response.d_connection->getDS());
×
465
      if (connIt != state->d_ownedConnectionsToBackend.end()) {
×
466
        auto& list = connIt->second;
×
467

468
        for (auto it = list.begin(); it != list.end(); ++it) {
×
469
          if (*it == response.d_connection) {
×
470
            try {
×
471
              response.d_connection->release(true);
×
472
            }
×
473
            catch (const std::exception& e) {
×
474
              vinfolog("Error releasing connection: %s", e.what());
×
475
            }
×
476
            list.erase(it);
×
477
            break;
×
478
          }
×
479
        }
×
480
      }
×
481
    }
×
482
  }
40✔
483

484
  if (response.d_buffer.size() < sizeof(dnsheader)) {
2,183✔
485
    state->terminateClientConnection();
2✔
486
    return;
2✔
487
  }
2✔
488

489
  if (!response.isAsync()) {
2,181✔
490
    try {
2,097✔
491
      auto& ids = response.d_idstate;
2,097✔
492
      std::shared_ptr<DownstreamState> backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
2,097!
493
      if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend, dnsdist::configuration::getCurrentRuntimeConfiguration().d_allowEmptyResponse)) {
2,097!
494
        state->terminateClientConnection();
3✔
495
        return;
3✔
496
      }
3✔
497

498
      if (backend != nullptr) {
2,094!
499
        ++backend->responses;
2,094✔
500
      }
2,094✔
501

502
      DNSResponse dnsResponse(ids, response.d_buffer, backend);
2,094✔
503
      dnsResponse.d_incomingTCPState = state;
2,094✔
504

505
      memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH));
2,094✔
506

507
      if (!processResponse(response.d_buffer, dnsResponse, false)) {
2,094✔
508
        state->terminateClientConnection();
6✔
509
        return;
6✔
510
      }
6✔
511

512
      if (dnsResponse.isAsynchronous()) {
2,088✔
513
        /* we are done for now */
514
        return;
79✔
515
      }
79✔
516
    }
2,088✔
517
    catch (const std::exception& e) {
2,097✔
518
      vinfolog("Unexpected exception while handling response from backend: %s", e.what());
4!
519
      state->terminateClientConnection();
4✔
520
      return;
4✔
521
    }
4✔
522
  }
2,097✔
523

524
  ++dnsdist::metrics::g_stats.responses;
2,089✔
525
  ++state->d_ci.cs->responses;
2,089✔
526

527
  queueResponse(state, now, std::move(response), true);
2,089✔
528
}
2,089✔
529

530
struct TCPCrossProtocolResponse
531
{
532
  TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now) :
533
    d_response(std::move(response)), d_state(state), d_now(now)
534
  {
257✔
535
  }
257✔
536
  TCPCrossProtocolResponse(const TCPCrossProtocolResponse&) = delete;
537
  TCPCrossProtocolResponse& operator=(const TCPCrossProtocolResponse&) = delete;
538
  TCPCrossProtocolResponse(TCPCrossProtocolResponse&&) = delete;
539
  TCPCrossProtocolResponse& operator=(TCPCrossProtocolResponse&&) = delete;
540
  ~TCPCrossProtocolResponse() = default;
257✔
541

542
  TCPResponse d_response;
543
  std::shared_ptr<IncomingTCPConnectionState> d_state;
544
  struct timeval d_now;
545
};
546

547
class TCPCrossProtocolQuery : public CrossProtocolQuery
548
{
549
public:
550
  TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> backend, std::shared_ptr<IncomingTCPConnectionState> sender) :
551
    CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), backend), d_sender(std::move(sender))
552
  {
216✔
553
  }
216✔
554
  TCPCrossProtocolQuery(const TCPCrossProtocolQuery&) = delete;
555
  TCPCrossProtocolQuery& operator=(const TCPCrossProtocolQuery&) = delete;
556
  TCPCrossProtocolQuery(TCPCrossProtocolQuery&&) = delete;
557
  TCPCrossProtocolQuery& operator=(TCPCrossProtocolQuery&&) = delete;
558
  ~TCPCrossProtocolQuery() override = default;
216✔
559

560
  std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
561
  {
194✔
562
    return d_sender;
194✔
563
  }
194✔
564

565
  DNSQuestion getDQ() override
566
  {
168✔
567
    auto& ids = query.d_idstate;
168✔
568
    DNSQuestion dnsQuestion(ids, query.d_buffer);
168✔
569
    dnsQuestion.d_incomingTCPState = d_sender;
168✔
570
    return dnsQuestion;
168✔
571
  }
168✔
572

573
  DNSResponse getDR() override
574
  {
72✔
575
    auto& ids = query.d_idstate;
72✔
576
    DNSResponse dnsResponse(ids, query.d_buffer, downstream);
72✔
577
    dnsResponse.d_incomingTCPState = d_sender;
72✔
578
    return dnsResponse;
72✔
579
  }
72✔
580

581
private:
582
  std::shared_ptr<IncomingTCPConnectionState> d_sender;
583
};
584

585
std::unique_ptr<CrossProtocolQuery> IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& backend)
586
{
4✔
587
  return std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(state), backend, shared_from_this());
4✔
588
}
4✔
589

590
std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion)
591
{
187✔
592
  auto state = dnsQuestion.getIncomingTCPState();
187✔
593
  if (!state) {
187!
594
    throw std::runtime_error("Trying to create a TCP cross protocol query without a valid TCP state");
×
595
  }
×
596

597
  dnsQuestion.ids.origID = dnsQuestion.getHeader()->id;
187✔
598
  return std::make_unique<TCPCrossProtocolQuery>(std::move(dnsQuestion.getMutableData()), std::move(dnsQuestion.ids), nullptr, std::move(state));
187✔
599
}
187✔
600

601
void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
602
{
257✔
603
  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
257✔
604
  try {
257✔
605
    auto ptr = std::make_unique<TCPCrossProtocolResponse>(std::move(response), state, now);
257✔
606
    if (!state->d_threadData.crossProtocolResponseSender.send(std::move(ptr))) {
257!
607
      ++dnsdist::metrics::g_stats.tcpCrossProtocolResponsePipeFull;
×
608
      vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
×
609
    }
×
610
  }
257✔
611
  catch (const std::exception& e) {
257✔
612
    vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
×
613
  }
×
614
}
257✔
615

616
IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::handleQuery(PacketBuffer&& queryIn, const struct timeval& now, std::optional<int32_t> streamID)
617
{
22,682✔
618
  auto query = std::move(queryIn);
22,682✔
619
  if (query.size() < sizeof(dnsheader)) {
22,682✔
620
    ++dnsdist::metrics::g_stats.nonCompliantQueries;
2✔
621
    ++d_ci.cs->nonCompliantQueries;
2✔
622
    return QueryProcessingResult::TooSmall;
2✔
623
  }
2✔
624

625
  ++d_queriesCount;
22,680✔
626
  ++d_ci.cs->queries;
22,680✔
627
  ++dnsdist::metrics::g_stats.queries;
22,680✔
628

629
  if (d_handler.isTLS()) {
22,680✔
630
    auto tlsVersion = d_handler.getTLSVersion();
20,894✔
631
    switch (tlsVersion) {
20,894✔
632
    case LibsslTLSVersion::TLS10:
×
633
      ++d_ci.cs->tls10queries;
×
634
      break;
×
635
    case LibsslTLSVersion::TLS11:
×
636
      ++d_ci.cs->tls11queries;
×
637
      break;
×
638
    case LibsslTLSVersion::TLS12:
9✔
639
      ++d_ci.cs->tls12queries;
9✔
640
      break;
9✔
641
    case LibsslTLSVersion::TLS13:
20,885✔
642
      ++d_ci.cs->tls13queries;
20,885✔
643
      break;
20,885✔
644
    default:
×
645
      ++d_ci.cs->tlsUnknownqueries;
×
646
    }
20,894✔
647
  }
20,894✔
648

649
  auto state = shared_from_this();
22,680✔
650
  InternalQueryState ids;
22,680✔
651
  ids.origDest = d_proxiedDestination;
22,680✔
652
  ids.origRemote = d_proxiedRemote;
22,680✔
653
  ids.cs = d_ci.cs;
22,680✔
654
  ids.queryRealTime.start();
22,680✔
655
  if (streamID) {
22,680✔
656
    ids.d_streamID = *streamID;
126✔
657
  }
126✔
658

659
  auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
22,680✔
660
  if (dnsCryptResponse) {
22,680!
661
    TCPResponse response;
×
662
    d_state = State::idle;
×
663
    ++d_currentQueriesCount;
×
664
    queueResponse(state, now, std::move(response), false);
×
665
    return QueryProcessingResult::SelfAnswered;
×
666
  }
×
667

668
  {
22,680✔
669
    /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
670
    const dnsheader_aligned dnsHeader(query.data());
22,680✔
671
    if (!checkQueryHeaders(*dnsHeader, *d_ci.cs)) {
22,680✔
672
      return QueryProcessingResult::InvalidHeaders;
4✔
673
    }
4✔
674

675
    if (dnsHeader->qdcount == 0) {
22,676✔
676
      TCPResponse response;
2✔
677
      auto queryID = dnsHeader->id;
2✔
678
      dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
2✔
679
        header.rcode = RCode::NotImp;
2✔
680
        header.qr = true;
2✔
681
        return true;
2✔
682
      });
2✔
683
      response.d_idstate = std::move(ids);
2✔
684
      response.d_idstate.origID = queryID;
2✔
685
      response.d_idstate.selfGenerated = true;
2✔
686
      response.d_buffer = std::move(query);
2✔
687
      d_state = State::idle;
2✔
688
      ++d_currentQueriesCount;
2✔
689
      queueResponse(state, now, std::move(response), false);
2✔
690
      return QueryProcessingResult::SelfAnswered;
2✔
691
    }
2✔
692
  }
22,676✔
693

694
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
695
  ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), static_cast<int>(query.size()), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
22,674✔
696
  ids.protocol = getProtocol();
22,674✔
697
  if (ids.dnsCryptQuery) {
22,674✔
698
    ids.protocol = dnsdist::Protocol::DNSCryptTCP;
8✔
699
  }
8✔
700

701
  DNSQuestion dnsQuestion(ids, query);
22,674✔
702
  dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) {
22,674✔
703
    const uint16_t* flags = getFlagsFromDNSHeader(&header);
22,673✔
704
    ids.origFlags = *flags;
22,673✔
705
    return true;
22,673✔
706
  });
22,673✔
707
  dnsQuestion.d_incomingTCPState = state;
22,674✔
708
  dnsQuestion.sni = d_handler.getServerNameIndication();
22,674✔
709

710
  if (d_proxyProtocolValues) {
22,674✔
711
    /* we need to copy them, because the next queries received on that connection will
712
       need to get the _unaltered_ values */
713
    dnsQuestion.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*d_proxyProtocolValues);
34✔
714
  }
34✔
715

716
  if (dnsQuestion.ids.qtype == QType::AXFR || dnsQuestion.ids.qtype == QType::IXFR) {
22,674✔
717
    dnsQuestion.ids.skipCache = true;
27✔
718
  }
27✔
719

720
  if (forwardViaUDPFirst()) {
22,674✔
721
    // if there was no EDNS, we add it with a large buffer size
722
    // so we can use UDP to talk to the backend.
723
    const dnsheader_aligned dnsHeader(query.data());
123✔
724
    if (dnsHeader->arcount == 0U) {
123✔
725
      if (addEDNS(query, 4096, false, 4096, 0)) {
117!
726
        dnsQuestion.ids.ednsAdded = true;
117✔
727
      }
117✔
728
    }
117✔
729
  }
123✔
730

731
  if (streamID) {
22,674✔
732
    auto unit = getDOHUnit(*streamID);
123✔
733
    if (unit) {
123!
734
      dnsQuestion.ids.du = std::move(unit);
123✔
735
    }
123✔
736
  }
123✔
737

738
  std::shared_ptr<DownstreamState> backend;
22,674✔
739
  auto result = processQuery(dnsQuestion, backend);
22,674✔
740

741
  if (result == ProcessQueryResult::Asynchronous) {
22,674✔
742
    /* we are done for now */
743
    ++d_currentQueriesCount;
108✔
744
    return QueryProcessingResult::Asynchronous;
108✔
745
  }
108✔
746

747
  if (streamID) {
22,566✔
748
    restoreDOHUnit(std::move(dnsQuestion.ids.du));
87✔
749
  }
87✔
750

751
  if (result == ProcessQueryResult::Drop) {
22,566✔
752
    return QueryProcessingResult::Dropped;
36✔
753
  }
36✔
754

755
  // the buffer might have been invalidated by now
756
  uint16_t queryID{0};
22,530✔
757
  {
22,530✔
758
    const auto dnsHeader = dnsQuestion.getHeader();
22,530✔
759
    queryID = dnsHeader->id;
22,530✔
760
  }
22,530✔
761

762
  if (result == ProcessQueryResult::SendAnswer) {
22,530✔
763
    TCPResponse response;
20,435✔
764
    {
20,435✔
765
      const auto dnsHeader = dnsQuestion.getHeader();
20,435✔
766
      memcpy(&response.d_cleartextDH, dnsHeader.get(), sizeof(response.d_cleartextDH));
20,435✔
767
    }
20,435✔
768
    response.d_idstate = std::move(ids);
20,435✔
769
    response.d_idstate.origID = queryID;
20,435✔
770
    response.d_idstate.selfGenerated = true;
20,435✔
771
    response.d_idstate.cs = d_ci.cs;
20,435✔
772
    response.d_buffer = std::move(query);
20,435✔
773

774
    d_state = State::idle;
20,435✔
775
    ++d_currentQueriesCount;
20,435✔
776
    queueResponse(state, now, std::move(response), false);
20,435✔
777
    return QueryProcessingResult::SelfAnswered;
20,435✔
778
  }
20,435✔
779

780
  if (result != ProcessQueryResult::PassToBackend || backend == nullptr) {
2,095!
781
    return QueryProcessingResult::NoBackend;
×
782
  }
×
783

784
  dnsQuestion.ids.origID = queryID;
2,095✔
785

786
  ++d_currentQueriesCount;
2,095✔
787

788
  std::string proxyProtocolPayload;
2,095✔
789
  if (backend->isDoH()) {
2,095✔
790
    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), backend->getNameWithAddr());
25✔
791

792
    /* we need to do this _before_ creating the cross protocol query because
793
       after that the buffer will have been moved */
794
    if (backend->d_config.useProxyProtocol) {
25✔
795
      proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
1✔
796
    }
1✔
797

798
    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(ids), backend, state);
25✔
799
    cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
25✔
800

801
    backend->passCrossProtocolQuery(std::move(cpq));
25✔
802
    return QueryProcessingResult::Forwarded;
25✔
803
  }
25✔
804
  if (!backend->isTCPOnly() && forwardViaUDPFirst()) {
2,070✔
805
    if (streamID) {
49!
806
      auto unit = getDOHUnit(*streamID);
49✔
807
      if (unit) {
49!
808
        dnsQuestion.ids.du = std::move(unit);
49✔
809
      }
49✔
810
    }
49✔
811
    if (assignOutgoingUDPQueryToBackend(backend, queryID, dnsQuestion, query)) {
49✔
812
      return QueryProcessingResult::Forwarded;
48✔
813
    }
48✔
814
    restoreDOHUnit(std::move(dnsQuestion.ids.du));
1✔
815
    // fallback to the normal flow
816
  }
1✔
817

818
  prependSizeToTCPQuery(query, 0);
2,022✔
819

820
  auto downstreamConnection = getDownstreamConnection(backend, dnsQuestion.proxyProtocolValues, now);
2,022✔
821

822
  if (backend->d_config.useProxyProtocol) {
2,022✔
823
    /* if we ever sent a TLV over a connection, we can never go back */
824
    if (!d_proxyProtocolPayloadHasTLV) {
47✔
825
      d_proxyProtocolPayloadHasTLV = dnsQuestion.proxyProtocolValues && !dnsQuestion.proxyProtocolValues->empty();
32!
826
    }
32✔
827

828
    proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
47✔
829
  }
47✔
830

831
  if (dnsQuestion.proxyProtocolValues) {
2,022✔
832
    downstreamConnection->setProxyProtocolValuesSent(std::move(dnsQuestion.proxyProtocolValues));
21✔
833
  }
21✔
834

835
  TCPQuery tcpquery(std::move(query), std::move(ids));
2,022✔
836
  tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
2,022✔
837

838
  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), backend->getNameWithAddr());
2,022✔
839
  std::shared_ptr<TCPQuerySender> incoming = state;
2,022✔
840
  downstreamConnection->queueQuery(incoming, std::move(tcpquery));
2,022✔
841
  return QueryProcessingResult::Forwarded;
2,022✔
842
}
2,070✔
843

844
void IncomingTCPConnectionState::handleIOCallback(int desc, FDMultiplexer::funcparam_t& param)
845
{
2,011✔
846
  auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
2,011✔
847
  if (desc != conn->d_handler.getDescriptor()) {
2,011!
848
    // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay): __PRETTY_FUNCTION__ is fine
849
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(desc) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor()));
×
850
  }
×
851

852
  conn->handleIO();
2,011✔
853
}
2,011✔
854

855
void IncomingTCPConnectionState::handleHandshakeDone(const struct timeval& now)
856
{
2,078✔
857
  if (d_handler.isTLS()) {
2,078✔
858
    if (!d_handler.hasTLSSessionBeenResumed()) {
370✔
859
      ++d_ci.cs->tlsNewSessions;
345✔
860
    }
345✔
861
    else {
25✔
862
      ++d_ci.cs->tlsResumptions;
25✔
863
    }
25✔
864
    if (d_handler.getResumedFromInactiveTicketKey()) {
370✔
865
      ++d_ci.cs->tlsInactiveTicketKey;
8✔
866
    }
8✔
867
    if (d_handler.getUnknownTicketKey()) {
370✔
868
      ++d_ci.cs->tlsUnknownTicketKey;
6✔
869
    }
6✔
870
  }
370✔
871

872
  d_handshakeDoneTime = now;
2,078✔
873
}
2,078✔
874

875
IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::handleProxyProtocolPayload()
876
{
19✔
877
  do {
32✔
878
    DEBUGLOG("reading proxy protocol header");
32✔
879
    auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed, false, isProxyPayloadOutsideTLS());
32✔
880
    if (iostate == IOState::Done) {
32✔
881
      d_buffer.resize(d_currentPos);
27✔
882
      ssize_t remaining = isProxyHeaderComplete(d_buffer);
27✔
883
      if (remaining == 0) {
27✔
884
        vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", d_ci.remote.toStringWithPort());
3!
885
        ++dnsdist::metrics::g_stats.proxyProtocolInvalid;
3✔
886
        return ProxyProtocolResult::Error;
3✔
887
      }
3✔
888
      if (remaining < 0) {
24✔
889
        d_proxyProtocolNeed += -remaining;
13✔
890
        d_buffer.resize(d_currentPos + d_proxyProtocolNeed);
13✔
891
        /* we need to keep reading, since we might have buffered data */
892
      }
13✔
893
      else {
11✔
894
        /* proxy header received */
895
        std::vector<ProxyProtocolValue> proxyProtocolValues;
11✔
896
        if (!handleProxyProtocol(d_ci.remote, true, dnsdist::configuration::getCurrentRuntimeConfiguration().d_ACL, d_buffer, d_proxiedRemote, d_proxiedDestination, proxyProtocolValues)) {
11!
897
          vinfolog("Error handling the Proxy Protocol received from TCP client %s", d_ci.remote.toStringWithPort());
×
898
          return ProxyProtocolResult::Error;
×
899
        }
×
900

901
        if (!proxyProtocolValues.empty()) {
11✔
902
          d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
9✔
903
        }
9✔
904

905
        d_currentPos = 0;
11✔
906
        d_proxyProtocolNeed = 0;
11✔
907
        d_buffer.clear();
11✔
908
        return ProxyProtocolResult::Done;
11✔
909
      }
11✔
910
    }
24✔
911
    else {
5✔
912
      d_lastIOBlocked = true;
5✔
913
    }
5✔
914
  } while (active() && !d_lastIOBlocked);
32✔
915

916
  return ProxyProtocolResult::Reading;
5✔
917
}
19✔
918

919
IOState IncomingTCPConnectionState::handleHandshake(const struct timeval& now)
920
{
2,045✔
921
  DEBUGLOG("doing handshake");
2,045✔
922
  auto iostate = d_handler.tryHandshake();
2,045✔
923
  if (iostate == IOState::Done) {
2,045✔
924
    DEBUGLOG("handshake done");
1,919✔
925
    handleHandshakeDone(now);
1,919✔
926

927
    if (d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && !isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
1,919!
928
      d_state = State::readingProxyProtocolHeader;
16✔
929
      d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
16✔
930
      d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
16✔
931
    }
16✔
932
    else {
1,903✔
933
      d_state = State::readingQuerySize;
1,903✔
934
    }
1,903✔
935
  }
1,919✔
936
  else {
126✔
937
    d_lastIOBlocked = true;
126✔
938
  }
126✔
939

940
  return iostate;
2,045✔
941
}
2,045✔
942

943
IOState IncomingTCPConnectionState::handleIncomingQueryReceived(const struct timeval& now)
944
{
22,554✔
945
  DEBUGLOG("query received");
22,554✔
946
  d_buffer.resize(d_querySize);
22,554✔
947

948
  d_state = State::idle;
22,554✔
949
  auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt);
22,554✔
950
  switch (processingResult) {
22,554✔
951
  case QueryProcessingResult::TooSmall:
×
952
    /* fall-through */
953
  case QueryProcessingResult::InvalidHeaders:
3✔
954
    /* fall-through */
955
  case QueryProcessingResult::Dropped:
36✔
956
    /* fall-through */
957
  case QueryProcessingResult::NoBackend:
36!
958
    terminateClientConnection();
36✔
959
    ;
36✔
960
  default:
22,550✔
961
    break;
22,550✔
962
  }
22,554✔
963

964
  /* the state might have been updated in the meantime, we don't want to override it
965
     in that case */
966
  if (active() && d_state != State::idle) {
22,550✔
967
    if (d_ioState->isWaitingForRead()) {
20,412✔
968
      return IOState::NeedRead;
20,408✔
969
    }
20,408✔
970
    if (d_ioState->isWaitingForWrite()) {
4!
971
      return IOState::NeedWrite;
4✔
972
    }
4✔
973
    return IOState::Done;
×
974
  }
4✔
975
  return IOState::Done;
2,138✔
976
};
22,550✔
977

978
void IncomingTCPConnectionState::handleExceptionDuringIO(const std::exception& exp)
979
{
1,780✔
980
  if (d_state == State::idle || d_state == State::waitingForQuery) {
1,780✔
981
    /* no need to increase any counters in that case, the client is simply done with us */
982
  }
1,442✔
983
  else if (d_state == State::doingHandshake || d_state == State::readingProxyProtocolHeader || d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery) {
338!
984
    ++d_ci.cs->tcpDiedReadingQuery;
338✔
985
  }
338✔
986
  else if (d_state == State::sendingResponse) {
×
987
    /* unlikely to happen here, the exception should be handled in sendResponse() */
988
    ++d_ci.cs->tcpDiedSendingResponse;
×
989
  }
×
990

991
  if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) {
1,780!
992
    DEBUGLOG("Got an exception while handling TCP query: " << exp.what());
338✔
993
    vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), exp.what());
338!
994
  }
338✔
995
  else {
1,442✔
996
    vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), exp.what());
1,442✔
997
    DEBUGLOG("Closing TCP client connection: " << exp.what());
1,442✔
998
  }
1,442✔
999
  /* remove this FD from the IO multiplexer */
1000
  terminateClientConnection();
1,780✔
1001
}
1,780✔
1002

1003
bool IncomingTCPConnectionState::readIncomingQuery(const timeval& now, IOState& iostate)
1004
{
26,220✔
1005
  if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize)) {
26,220!
1006
    DEBUGLOG("reading query size");
26,069✔
1007
    d_buffer.resize(sizeof(uint16_t));
26,069✔
1008
    iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t));
26,069✔
1009
    if (d_currentPos > 0) {
26,069✔
1010
      /* if we got at least one byte, we can't go around sending responses */
1011
      d_state = State::readingQuerySize;
22,565✔
1012
    }
22,565✔
1013

1014
    if (iostate == IOState::Done) {
26,069✔
1015
      DEBUGLOG("query size received");
22,561✔
1016
      d_state = State::readingQuery;
22,561✔
1017
      d_querySizeReadTime = now;
22,561✔
1018
      if (d_queriesCount == 0) {
22,561✔
1019
        d_firstQuerySizeReadTime = now;
1,579✔
1020
      }
1,579✔
1021
      d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1);
22,561✔
1022
      if (d_querySize < sizeof(dnsheader)) {
22,561✔
1023
        /* go away */
1024
        terminateClientConnection();
2✔
1025
        return true;
2✔
1026
      }
2✔
1027

1028
      d_buffer.resize(d_querySize);
22,559✔
1029
      d_currentPos = 0;
22,559✔
1030
    }
22,559✔
1031
    else {
3,508✔
1032
      d_lastIOBlocked = true;
3,508✔
1033
    }
3,508✔
1034
  }
26,069✔
1035

1036
  if (!d_lastIOBlocked && d_state == State::readingQuery) {
26,218!
1037
    DEBUGLOG("reading query");
22,710✔
1038
    iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize);
22,710✔
1039
    if (iostate == IOState::Done) {
22,710✔
1040
      iostate = handleIncomingQueryReceived(now);
22,554✔
1041
    }
22,554✔
1042
    else {
156✔
1043
      d_lastIOBlocked = true;
156✔
1044
    }
156✔
1045
  }
22,710✔
1046

1047
  return false;
26,218✔
1048
}
26,220✔
1049

1050
void IncomingTCPConnectionState::handleIO()
1051
{
6,009✔
1052
  // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
1053
  // even though the underlying socket is not ready, so we need to actually ask for the data first
1054
  IOState iostate = IOState::Done;
6,009✔
1055
  timeval now{};
6,009✔
1056
  gettimeofday(&now, nullptr);
6,009✔
1057

1058
  do {
26,619✔
1059
    iostate = IOState::Done;
26,619✔
1060
    IOStateGuard ioGuard(d_ioState);
26,619✔
1061

1062
    if (maxConnectionDurationReached(dnsdist::configuration::getCurrentRuntimeConfiguration().d_maxTCPConnectionDuration, now)) {
26,619✔
1063
      vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
1!
1064
      // will be handled by the ioGuard
1065
      // handleNewIOState(state, IOState::Done, fd, handleIOCallback);
1066
      return;
1✔
1067
    }
1✔
1068

1069
    d_lastIOBlocked = false;
26,618✔
1070

1071
    try {
26,618✔
1072
      if (d_state == State::starting) {
26,618✔
1073
        if (d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
1,920!
1074
          d_state = State::readingProxyProtocolHeader;
1✔
1075
          d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
1✔
1076
          d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
1✔
1077
        }
1✔
1078
        else {
1,919✔
1079
          d_state = State::doingHandshake;
1,919✔
1080
        }
1,919✔
1081
      }
1,920✔
1082

1083
      if (d_state == State::doingHandshake) {
26,618✔
1084
        iostate = handleHandshake(now);
2,044✔
1085
      }
2,044✔
1086

1087
      if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) {
26,618✔
1088
        auto status = handleProxyProtocolPayload();
17✔
1089
        if (status == ProxyProtocolResult::Done) {
17✔
1090
          d_buffer.resize(sizeof(uint16_t));
9✔
1091

1092
          if (isProxyPayloadOutsideTLS()) {
9✔
1093
            d_state = State::doingHandshake;
1✔
1094
            iostate = handleHandshake(now);
1✔
1095
          }
1✔
1096
          else {
8✔
1097
            d_state = State::readingQuerySize;
8✔
1098
          }
8✔
1099
        }
9✔
1100
        else if (status == ProxyProtocolResult::Error) {
8✔
1101
          iostate = IOState::Done;
3✔
1102
        }
3✔
1103
        else {
5✔
1104
          iostate = IOState::NeedRead;
5✔
1105
        }
5✔
1106
      }
17✔
1107

1108
      if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery)) {
26,618✔
1109
        if (readIncomingQuery(now, iostate)) {
26,220✔
1110
          return;
2✔
1111
        }
2✔
1112
      }
26,220✔
1113

1114
      if (!d_lastIOBlocked && d_state == State::sendingResponse) {
26,616✔
1115
        DEBUGLOG("sending response");
14✔
1116
        iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
14✔
1117
        if (iostate == IOState::Done) {
14!
1118
          DEBUGLOG("response sent from " << __PRETTY_FUNCTION__);
14✔
1119
          handleResponseSent(d_currentResponse, d_currentResponse.d_buffer.size());
14✔
1120
          d_state = State::idle;
14✔
1121
        }
14✔
1122
        else {
×
1123
          d_lastIOBlocked = true;
×
1124
        }
×
1125
      }
14✔
1126

1127
      if (active() && !d_lastIOBlocked && iostate == IOState::Done && (d_state == State::idle || d_state == State::waitingForQuery)) {
26,616!
1128
        // try sending queued responses
1129
        DEBUGLOG("send responses, if any");
2,334✔
1130
        auto state = shared_from_this();
2,334✔
1131
        iostate = sendQueuedResponses(state, now);
2,334✔
1132

1133
        if (!d_lastIOBlocked && active() && iostate == IOState::Done) {
2,334!
1134
          // if the query has been passed to a backend, or dropped, and the responses have been sent,
1135
          // we can start reading again
1136
          if (canAcceptNewQueries(now)) {
2,330✔
1137
            resetForNewQuery();
214✔
1138
            iostate = IOState::NeedRead;
214✔
1139
          }
214✔
1140
          else {
2,116✔
1141
            d_state = State::idle;
2,116✔
1142
            iostate = IOState::Done;
2,116✔
1143
          }
2,116✔
1144
        }
2,330✔
1145
      }
2,334✔
1146

1147
      if (d_state != State::idle && d_state != State::doingHandshake && d_state != State::readingProxyProtocolHeader && d_state != State::waitingForQuery && d_state != State::readingQuerySize && d_state != State::readingQuery && d_state != State::sendingResponse) {
26,616!
1148
        vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(d_state));
×
1149
      }
×
1150
    }
26,616✔
1151
    catch (const std::exception& exp) {
26,618✔
1152
      /* most likely an EOF because the other end closed the connection,
1153
         but it might also be a real IO error or something else.
1154
         Let's just drop the connection
1155
      */
1156
      handleExceptionDuringIO(exp);
1,780✔
1157
    }
1,780✔
1158

1159
    if (!active()) {
26,616✔
1160
      DEBUGLOG("state is no longer active");
1,848✔
1161
      return;
1,848✔
1162
    }
1,848✔
1163

1164
    auto sharedPtrToConn = shared_from_this();
24,768✔
1165
    if (iostate == IOState::Done) {
24,768✔
1166
      d_ioState->update(iostate, handleIOCallback, sharedPtrToConn);
2,119✔
1167
    }
2,119✔
1168
    else {
22,649✔
1169
      updateIO(iostate, now);
22,649✔
1170
    }
22,649✔
1171
    ioGuard.release();
24,768✔
1172
  } while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !d_lastIOBlocked);
24,768✔
1173
}
6,009✔
1174

1175
void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPResponse&& response)
1176
{
61✔
1177
  if (std::this_thread::get_id() != d_creatorThreadID) {
61✔
1178
    /* empty buffer will signal an IO error */
1179
    response.d_buffer.clear();
18✔
1180
    handleCrossProtocolResponse(now, std::move(response));
18✔
1181
    return;
18✔
1182
  }
18✔
1183

1184
  auto sharedPtrToConn = shared_from_this();
43✔
1185
  --sharedPtrToConn->d_currentQueriesCount;
43✔
1186
  sharedPtrToConn->d_hadErrors = true;
43✔
1187

1188
  if (sharedPtrToConn->d_state == State::sendingResponse) {
43✔
1189
    /* if we have responses to send, let's do that first */
1190
  }
2✔
1191
  else if (!sharedPtrToConn->d_queuedResponses.empty()) {
41!
1192
    /* stop reading and send what we have */
1193
    try {
×
1194
      auto iostate = sendQueuedResponses(sharedPtrToConn, now);
×
1195

1196
      if (sharedPtrToConn->active() && iostate != IOState::Done) {
×
1197
        // we need to update the state right away, nobody will do that for us
1198
        updateIO(iostate, now);
×
1199
      }
×
1200
    }
×
1201
    catch (const std::exception& e) {
×
1202
      vinfolog("Exception in notifyIOError: %s", e.what());
×
1203
    }
×
1204
  }
×
1205
  else {
41✔
1206
    // the backend code already tried to reconnect if it was possible
1207
    sharedPtrToConn->terminateClientConnection();
41✔
1208
  }
41✔
1209
}
43✔
1210

1211
static bool processXFRResponse(PacketBuffer& response, DNSResponse& dnsResponse)
1212
{
449✔
1213
  const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains;
449✔
1214
  const auto& xfrRespRuleActions = dnsdist::rules::getResponseRuleChain(chains, dnsdist::rules::ResponseRuleChain::XFRResponseRules);
449✔
1215

1216
  if (!applyRulesToResponse(xfrRespRuleActions, dnsResponse)) {
449!
1217
    return false;
×
1218
  }
×
1219

1220
  if (dnsResponse.isAsynchronous()) {
449!
1221
    return true;
×
1222
  }
×
1223

1224
  if (dnsResponse.ids.d_extendedError) {
449!
1225
    dnsdist::edns::addExtendedDNSError(dnsResponse.getMutableData(), dnsResponse.getMaximumSize(), dnsResponse.ids.d_extendedError->infoCode, dnsResponse.ids.d_extendedError->extraText);
×
1226
  }
×
1227

1228
  return true;
449✔
1229
}
449✔
1230

1231
void IncomingTCPConnectionState::handleXFRResponse(const struct timeval& now, TCPResponse&& response)
1232
{
449✔
1233
  if (std::this_thread::get_id() != d_creatorThreadID) {
449!
1234
    handleCrossProtocolResponse(now, std::move(response));
×
1235
    return;
×
1236
  }
×
1237

1238
  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
449✔
1239
  auto& ids = response.d_idstate;
449✔
1240
  std::shared_ptr<DownstreamState> backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
449!
1241
  DNSResponse dnsResponse(ids, response.d_buffer, backend);
449✔
1242
  dnsResponse.d_incomingTCPState = state;
449✔
1243
  memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH));
449✔
1244

1245
  if (!processXFRResponse(response.d_buffer, dnsResponse)) {
449!
1246
    state->terminateClientConnection();
×
1247
    return;
×
1248
  }
×
1249

1250
  queueResponse(state, now, std::move(response), true);
449✔
1251
}
449✔
1252

1253
void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write)
1254
{
17✔
1255
  vinfolog("Timeout while %s TCP client %s", (write ? "writing to" : "reading from"), state->d_ci.remote.toStringWithPort());
17!
1256
  DEBUGLOG("client timeout");
17✔
1257
  DEBUGLOG("Processed " << state->d_queriesCount << " queries, current count is " << state->d_currentQueriesCount << ", " << state->d_ownedConnectionsToBackend.size() << " owned connections, " << state->d_queuedResponses.size() << " response queued");
17✔
1258

1259
  if (write || state->d_currentQueriesCount == 0) {
17✔
1260
    ++state->d_ci.cs->tcpClientTimeouts;
11✔
1261
    state->d_ioState.reset();
11✔
1262
  }
11✔
1263
  else {
6✔
1264
    DEBUGLOG("Going idle");
6✔
1265
    /* we still have some queries in flight, let's just stop reading for now */
1266
    state->d_state = State::idle;
6✔
1267
    state->d_ioState->update(IOState::Done, handleIOCallback, state);
6✔
1268
  }
6✔
1269
}
17✔
1270

1271
static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
1272
{
1,981✔
1273
  auto* threadData = boost::any_cast<TCPClientThreadData*>(param);
1,981✔
1274

1275
  std::unique_ptr<ConnectionInfo> citmp{nullptr};
1,981✔
1276
  try {
1,981✔
1277
    auto tmp = threadData->queryReceiver.receive();
1,981✔
1278
    if (!tmp) {
1,981!
1279
      return;
×
1280
    }
×
1281
    citmp = std::move(*tmp);
1,981✔
1282
  }
1,981✔
1283
  catch (const std::exception& e) {
1,981✔
1284
    throw std::runtime_error("Error while reading from the TCP query channel: " + std::string(e.what()));
×
1285
  }
×
1286

1287
  g_tcpclientthreads->decrementQueuedCount();
1,981✔
1288

1289
  timeval now{};
1,981✔
1290
  gettimeofday(&now, nullptr);
1,981✔
1291

1292
  if (citmp->cs->dohFrontend) {
1,981✔
1293
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
153✔
1294
    auto state = std::make_shared<IncomingHTTP2Connection>(std::move(*citmp), *threadData, now);
153✔
1295
    state->handleIO();
153✔
1296
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
153✔
1297
  }
153✔
1298
  else {
1,828✔
1299
    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
1,828✔
1300
    state->handleIO();
1,828✔
1301
  }
1,828✔
1302
}
1,981✔
1303

1304
static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
1305
{
186✔
1306
  auto* threadData = boost::any_cast<TCPClientThreadData*>(param);
186✔
1307

1308
  std::unique_ptr<CrossProtocolQuery> cpq{nullptr};
186✔
1309
  try {
186✔
1310
    auto tmp = threadData->crossProtocolQueryReceiver.receive();
186✔
1311
    if (!tmp) {
186!
1312
      return;
×
1313
    }
×
1314
    cpq = std::move(*tmp);
186✔
1315
  }
186✔
1316
  catch (const std::exception& e) {
186✔
1317
    throw std::runtime_error("Error while reading from the TCP cross-protocol channel: " + std::string(e.what()));
×
1318
  }
×
1319

1320
  timeval now{};
186✔
1321
  gettimeofday(&now, nullptr);
186✔
1322

1323
  std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
186✔
1324
  auto query = std::move(cpq->query);
186✔
1325
  auto downstreamServer = std::move(cpq->downstream);
186✔
1326

1327
  try {
186✔
1328
    auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
186✔
1329

1330
    prependSizeToTCPQuery(query.d_buffer, query.d_idstate.d_proxyProtocolPayloadSize);
186✔
1331

1332
    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr());
186✔
1333

1334
    downstream->queueQuery(tqs, std::move(query));
186✔
1335
  }
186✔
1336
  catch (...) {
186✔
1337
    tqs->notifyIOError(now, std::move(query));
×
1338
  }
×
1339
}
186✔
1340

1341
static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param)
1342
{
257✔
1343
  auto* threadData = boost::any_cast<TCPClientThreadData*>(param);
257✔
1344

1345
  std::unique_ptr<TCPCrossProtocolResponse> cpr{nullptr};
257✔
1346
  try {
257✔
1347
    auto tmp = threadData->crossProtocolResponseReceiver.receive();
257✔
1348
    if (!tmp) {
257!
1349
      return;
×
1350
    }
×
1351
    cpr = std::move(*tmp);
257✔
1352
  }
257✔
1353
  catch (const std::exception& e) {
257✔
1354
    throw std::runtime_error("Error while reading from the TCP cross-protocol response: " + std::string(e.what()));
×
1355
  }
×
1356

1357
  auto& response = *cpr;
257✔
1358

1359
  try {
257✔
1360
    if (response.d_response.d_buffer.empty()) {
257✔
1361
      response.d_state->notifyIOError(response.d_now, std::move(response.d_response));
24✔
1362
    }
24✔
1363
    else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) {
233!
1364
      response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
×
1365
    }
×
1366
    else {
233✔
1367
      response.d_state->handleResponse(response.d_now, std::move(response.d_response));
233✔
1368
    }
233✔
1369
  }
257✔
1370
  catch (...) {
257✔
1371
    /* no point bubbling up from there */
1372
  }
×
1373
}
257✔
1374

1375
struct TCPAcceptorParam
1376
{
1377
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
1378
  ClientState& clientState;
1379
  ComboAddress local;
1380
  int socket{-1};
1381
};
1382

1383
static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData);
1384

1385
static void scanForTimeouts(const TCPClientThreadData& data, const timeval& now)
1386
{
6,477✔
1387
  auto expiredReadConns = data.mplexer->getTimeouts(now, false);
6,477✔
1388
  for (const auto& cbData : expiredReadConns) {
6,477!
1389
    if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
×
1390
      auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
×
1391
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1392
        vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
×
1393
        state->handleTimeout(state, false);
×
1394
      }
×
1395
    }
×
1396
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1397
    else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
×
1398
      auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second);
1399
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1400
        vinfolog("Timeout (read) from remote H2 client %s", state->d_ci.remote.toStringWithPort());
×
1401
        std::shared_ptr<IncomingTCPConnectionState> parentState = state;
1402
        state->handleTimeout(parentState, false);
1403
      }
1404
    }
1405
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1406
    else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
×
1407
      auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
×
1408
      vinfolog("Timeout (read) from remote backend %s", conn->getBackendName());
×
1409
      conn->handleTimeout(now, false);
×
1410
    }
×
1411
  }
×
1412

1413
  auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
6,477✔
1414
  for (const auto& cbData : expiredWriteConns) {
6,477!
1415
    if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
×
1416
      auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
×
1417
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1418
        vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
×
1419
        state->handleTimeout(state, true);
×
1420
      }
×
1421
    }
×
1422
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1423
    else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
×
1424
      auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second);
1425
      if (cbData.first == state->d_handler.getDescriptor()) {
×
1426
        vinfolog("Timeout (write) from remote H2 client %s", state->d_ci.remote.toStringWithPort());
×
1427
        std::shared_ptr<IncomingTCPConnectionState> parentState = state;
1428
        state->handleTimeout(parentState, true);
1429
      }
1430
    }
1431
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1432
    else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
×
1433
      auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
×
1434
      vinfolog("Timeout (write) from remote backend %s", conn->getBackendName());
×
1435
      conn->handleTimeout(now, true);
×
1436
    }
×
1437
  }
×
1438
}
6,477✔
1439

1440
static void dumpTCPStates(const TCPClientThreadData& data)
1441
{
×
1442
  /* just to keep things clean in the output, debug only */
1443
  static std::mutex s_lock;
×
1444
  std::lock_guard<decltype(s_lock)> lck(s_lock);
×
1445
  if (g_tcpStatesDumpRequested > 0) {
×
1446
    /* no race here, we took the lock so it can only be increased in the meantime */
1447
    --g_tcpStatesDumpRequested;
×
1448
    infolog("Dumping the TCP states, as requested:");
×
1449
    data.mplexer->runForAllWatchedFDs([](bool isRead, int desc, const FDMultiplexer::funcparam_t& param, struct timeval ttd) {
×
1450
      timeval lnow{};
×
1451
      gettimeofday(&lnow, nullptr);
×
1452
      if (ttd.tv_sec > 0) {
×
1453
        infolog("- Descriptor %d is in %s state, TTD in %d", desc, (isRead ? "read" : "write"), (ttd.tv_sec - lnow.tv_sec));
×
1454
      }
×
1455
      else {
×
1456
        infolog("- Descriptor %d is in %s state, no TTD set", desc, (isRead ? "read" : "write"));
×
1457
      }
×
1458

1459
      if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
×
1460
        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
×
1461
        infolog(" - %s", state->toString());
×
1462
      }
×
1463
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1464
      else if (param.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) {
×
1465
        auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
1466
        infolog(" - %s", state->toString());
1467
      }
1468
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1469
      else if (param.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
×
1470
        auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param);
×
1471
        infolog(" - %s", conn->toString());
×
1472
      }
×
1473
      else if (param.type() == typeid(TCPClientThreadData*)) {
×
1474
        infolog(" - Worker thread pipe");
×
1475
      }
×
1476
    });
×
1477
    infolog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount());
×
1478
  }
×
1479
}
×
1480

1481
// NOLINTNEXTLINE(performance-unnecessary-value-param): you are wrong, clang-tidy, go home
1482
static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates)
1483
{
2,989✔
1484
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
1485
     from that point on */
1486

1487
  setThreadName("dnsdist/tcpClie");
2,989✔
1488

1489
  try {
2,989✔
1490
    TCPClientThreadData data;
2,989✔
1491
    data.crossProtocolResponseSender = std::move(crossProtocolResponseSender);
2,989✔
1492
    data.queryReceiver = std::move(queryReceiver);
2,989✔
1493
    data.crossProtocolQueryReceiver = std::move(crossProtocolQueryReceiver);
2,989✔
1494
    data.crossProtocolResponseReceiver = std::move(crossProtocolResponseReceiver);
2,989✔
1495

1496
    data.mplexer->addReadFD(data.queryReceiver.getDescriptor(), handleIncomingTCPQuery, &data);
2,989✔
1497
    data.mplexer->addReadFD(data.crossProtocolQueryReceiver.getDescriptor(), handleCrossProtocolQuery, &data);
2,989✔
1498
    data.mplexer->addReadFD(data.crossProtocolResponseReceiver.getDescriptor(), handleCrossProtocolResponse, &data);
2,989✔
1499

1500
    /* only used in single acceptor mode for now */
1501
    std::vector<TCPAcceptorParam> acceptParams;
2,989✔
1502
    acceptParams.reserve(tcpAcceptStates.size());
2,989✔
1503

1504
    for (auto& state : tcpAcceptStates) {
2,989!
1505
      acceptParams.emplace_back(TCPAcceptorParam{*state, state->local, state->tcpFD});
×
1506
      for (const auto& [addr, socket] : state->d_additionalAddresses) {
×
1507
        acceptParams.emplace_back(TCPAcceptorParam{*state, addr, socket});
×
1508
      }
×
1509
    }
×
1510

1511
    auto acceptCallback = [&data](int socket, FDMultiplexer::funcparam_t& funcparam) {
2,989✔
1512
      const auto* acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam);
×
1513
      acceptNewConnection(*acceptorParam, &data);
×
1514
    };
×
1515

1516
    for (const auto& param : acceptParams) {
2,989!
1517
      setNonBlocking(param.socket);
×
1518
      data.mplexer->addReadFD(param.socket, acceptCallback, &param);
×
1519
    }
×
1520

1521
    timeval now{};
2,989✔
1522
    gettimeofday(&now, nullptr);
2,989✔
1523
    time_t lastTimeoutScan = now.tv_sec;
2,989✔
1524

1525
    for (;;) {
23,107✔
1526
      data.mplexer->run(&now);
23,107✔
1527

1528
      try {
23,107✔
1529
        t_downstreamTCPConnectionsManager.cleanupClosedConnections(now);
23,107✔
1530

1531
        if (now.tv_sec > lastTimeoutScan) {
23,107✔
1532
          lastTimeoutScan = now.tv_sec;
6,576✔
1533
          scanForTimeouts(data, now);
6,576✔
1534

1535
          if (g_tcpStatesDumpRequested > 0) {
6,576!
1536
            dumpTCPStates(data);
×
1537
          }
×
1538
        }
6,576✔
1539
      }
23,107✔
1540
      catch (const std::exception& e) {
23,107✔
1541
        warnlog("Error in TCP worker thread: %s", e.what());
×
1542
      }
×
1543
    }
23,107✔
1544
  }
2,989✔
1545
  catch (const std::exception& e) {
2,989✔
1546
    errlog("Fatal error in TCP worker thread: %s", e.what());
×
1547
  }
×
1548
}
2,989✔
1549

1550
static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData)
1551
{
2,363✔
1552
  auto& clientState = param.clientState;
2,363✔
1553
  const bool checkACL = clientState.dohFrontend == nullptr || (!clientState.dohFrontend->d_trustForwardedForHeader && clientState.dohFrontend->d_earlyACLDrop);
2,363!
1554
  const int socket = param.socket;
2,363✔
1555
  bool tcpClientCountIncremented = false;
2,363✔
1556
  ComboAddress remote;
2,363✔
1557
  remote.sin4.sin_family = param.local.sin4.sin_family;
2,363✔
1558

1559
  tcpClientCountIncremented = false;
2,363✔
1560
  try {
2,363✔
1561
    socklen_t remlen = remote.getSocklen();
2,363✔
1562
    ConnectionInfo connInfo(&clientState);
2,363✔
1563
#ifdef HAVE_ACCEPT4
2,363✔
1564
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1565
    connInfo.fd = accept4(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK);
2,363✔
1566
#else
1567
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1568
    connInfo.fd = accept(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen);
1569
#endif
1570
    // will be decremented when the ConnectionInfo object is destroyed, no matter the reason
1571
    auto concurrentConnections = ++clientState.tcpCurrentConnections;
2,363✔
1572

1573
    if (connInfo.fd < 0) {
2,363!
1574
      throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
×
1575
    }
×
1576

1577
    if (checkACL && !dnsdist::configuration::getCurrentRuntimeConfiguration().d_ACL.match(remote)) {
2,363✔
1578
      ++dnsdist::metrics::g_stats.aclDrops;
9✔
1579
      vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
9✔
1580
      return;
9✔
1581
    }
9✔
1582

1583
    if (clientState.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > clientState.d_tcpConcurrentConnectionsLimit) {
2,354✔
1584
      vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort());
3✔
1585
      return;
3✔
1586
    }
3✔
1587

1588
    if (concurrentConnections > clientState.tcpMaxConcurrentConnections.load()) {
2,351✔
1589
      clientState.tcpMaxConcurrentConnections.store(concurrentConnections);
473✔
1590
    }
473✔
1591

1592
#ifndef HAVE_ACCEPT4
1593
    if (!setNonBlocking(connInfo.fd)) {
1594
      return;
1595
    }
1596
#endif
1597

1598
    setTCPNoDelay(connInfo.fd); // disable NAGLE
2,351✔
1599

1600
    const auto maxTCPQueuedConnections = dnsdist::configuration::getImmutableConfiguration().d_maxTCPQueuedConnections;
2,351✔
1601
    if (maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= maxTCPQueuedConnections) {
2,351!
1602
      vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
×
1603
      return;
×
1604
    }
×
1605

1606
    if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) {
2,351✔
1607
      vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
2✔
1608
      return;
2✔
1609
    }
2✔
1610
    tcpClientCountIncremented = true;
2,349✔
1611

1612
    vinfolog("Got TCP connection from %s", remote.toStringWithPort());
2,349✔
1613

1614
    connInfo.remote = remote;
2,349✔
1615

1616
    if (threadData == nullptr) {
2,349✔
1617
      if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(connInfo)))) {
1,981!
1618
        if (tcpClientCountIncremented) {
×
1619
          dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
×
1620
        }
×
1621
      }
×
1622
    }
1,981✔
1623
    else {
368✔
1624
      timeval now{};
368✔
1625
      gettimeofday(&now, nullptr);
368✔
1626

1627
      if (connInfo.cs->dohFrontend) {
368!
1628
#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
1629
        auto state = std::make_shared<IncomingHTTP2Connection>(std::move(connInfo), *threadData, now);
1630
        state->handleIO();
1631
#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */
1632
      }
×
1633
      else {
368✔
1634
        auto state = std::make_shared<IncomingTCPConnectionState>(std::move(connInfo), *threadData, now);
368✔
1635
        state->handleIO();
368✔
1636
      }
368✔
1637
    }
368✔
1638
  }
2,349✔
1639
  catch (const std::exception& e) {
2,363✔
1640
    errlog("While reading a TCP question: %s", e.what());
×
1641
    if (tcpClientCountIncremented) {
×
1642
      dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote);
×
1643
    }
×
1644
  }
×
1645
  catch (...) {
2,363✔
1646
  }
×
1647
}
2,363✔
1648

1649
/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
1650
   they will hand off to worker threads & spawn more of them if required
1651
*/
1652
#ifndef USE_SINGLE_ACCEPTOR_THREAD
1653
void tcpAcceptorThread(const std::vector<ClientState*>& states)
1654
{
368✔
1655
  setThreadName("dnsdist/tcpAcce");
368✔
1656

1657
  std::vector<TCPAcceptorParam> params;
368✔
1658
  params.reserve(states.size());
368✔
1659

1660
  for (const auto& state : states) {
368✔
1661
    params.emplace_back(TCPAcceptorParam{*state, state->local, state->tcpFD});
368✔
1662
    for (const auto& [addr, socket] : state->d_additionalAddresses) {
368!
1663
      params.emplace_back(TCPAcceptorParam{*state, addr, socket});
×
1664
    }
×
1665
  }
368✔
1666

1667
  if (params.size() == 1) {
368!
1668
    while (true) {
2,731✔
1669
      acceptNewConnection(params.at(0), nullptr);
2,363✔
1670
    }
2,363✔
1671
  }
368✔
1672
  else {
×
1673
    auto acceptCallback = [](int socket, FDMultiplexer::funcparam_t& funcparam) {
×
1674
      const auto* acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam);
×
1675
      acceptNewConnection(*acceptorParam, nullptr);
×
1676
    };
×
1677

1678
    auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(params.size()));
×
1679
    for (const auto& param : params) {
×
1680
      mplexer->addReadFD(param.socket, acceptCallback, &param);
×
1681
    }
×
1682

1683
    timeval now{};
×
1684
    while (true) {
×
1685
      mplexer->run(&now, -1);
×
1686
    }
×
1687
  }
×
1688
}
368✔
1689
#endif