• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 12595591960

03 Jan 2025 09:27AM UTC coverage: 62.774% (+2.5%) from 60.245%
12595591960

Pull #15008

github

web-flow
Merge c2a2749d3 into 788f396a7
Pull Request #15008: Do not follow CNAME records for ANY or CNAME queries

30393 of 78644 branches covered (38.65%)

Branch coverage included in aggregate %.

105822 of 138350 relevant lines covered (76.49%)

4613078.44 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.77
/pdns/tcpiohandler.hh
1

2
#pragma once
3
#include <memory>
4
/* needed for proper TCP_FASTOPEN_CONNECT detection */
5
#include <netinet/tcp.h>
6

7
#include "iputils.hh"
8
#include "libssl.hh"
9
#include "misc.hh"
10
#include "noinitvector.hh"
11

12
/* Outcome of a (possibly non-blocking) I/O step.
   Async is only returned for TLS connections, if OpenSSL's async mode has been enabled */
enum class IOState : uint8_t
{
  Done = 0,      /* the requested operation completed */
  NeedRead = 1,  /* the operation would block waiting for the descriptor to become readable */
  NeedWrite = 2, /* the operation would block waiting for the descriptor to become writable */
  Async = 3      /* see above: TLS + OpenSSL async mode only */
};
14

15
/* Opaque, backend-specific TLS session object, exchanged through
   TLSConnection::getSessions() and TLSConnection::setSession().
   Only polymorphic destruction is part of the common interface. */
class TLSSession
{
public:
  virtual ~TLSSession() noexcept = default;
};
20

21
class TLSConnection
22
{
23
public:
24
  virtual ~TLSConnection() = default;
899✔
25
  virtual void doHandshake() = 0;
26
  virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0;
27
  virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0;
28
  virtual IOState tryHandshake() = 0;
29
  virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) = 0;
30
  virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0;
31
  virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
32
  virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0;
33
  virtual std::string getServerNameIndication() const = 0;
34
  virtual std::vector<uint8_t> getNextProtocol() const = 0;
35
  virtual LibsslTLSVersion getTLSVersion() const = 0;
36
  virtual bool hasSessionBeenResumed() const = 0;
37
  virtual std::vector<std::unique_ptr<TLSSession>> getSessions() = 0;
38
  virtual void setSession(std::unique_ptr<TLSSession>& session) = 0;
39
  virtual bool isUsable() const = 0;
40
  virtual std::vector<int> getAsyncFDs() = 0;
41
  virtual void close() = 0;
42

43
  void setUnknownTicketKey()
44
  {
6✔
45
    d_unknownTicketKey = true;
6✔
46
  }
6✔
47

48
  bool getUnknownTicketKey() const
49
  {
372✔
50
    return d_unknownTicketKey;
372✔
51
  }
372✔
52

53
  void setResumedFromInactiveTicketKey()
54
  {
8✔
55
    d_resumedFromInactiveTicketKey = true;
8✔
56
  }
8✔
57

58
  bool getResumedFromInactiveTicketKey() const
59
  {
372✔
60
    return d_resumedFromInactiveTicketKey;
372✔
61
  }
372✔
62

63
protected:
64
  int d_socket{-1};
65
  bool d_unknownTicketKey{false};
66
  bool d_resumedFromInactiveTicketKey{false};
67
};
68

69
class TLSCtx
70
{
71
public:
72
  TLSCtx()
73
  {
176✔
74
    d_rotatingTicketsKey.clear();
176✔
75
  }
176✔
76
  virtual ~TLSCtx() = default;
82✔
77
  virtual std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) = 0;
78
  virtual std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) = 0;
79
  virtual void rotateTicketsKey(time_t now) = 0;
80
  virtual void loadTicketsKeys(const std::string& /* file */)
81
  {
×
82
    throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
×
83
  }
×
84
  virtual void loadTicketsKey(const std::string& /* key */)
85
  {
×
86
    throw std::runtime_error("This TLS backend does not have the capability to load a ticket key");
×
87
  }
×
88
  void handleTicketsKeyRotation(time_t now)
89
  {
333✔
90
    if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
333!
91
      if (d_rotatingTicketsKey.test_and_set()) {
58!
92
        /* someone is already rotating */
93
        return;
×
94
      }
×
95
      try {
58✔
96
        rotateTicketsKey(now);
58✔
97
        d_rotatingTicketsKey.clear();
58✔
98
      }
58✔
99
      catch(const std::runtime_error& e) {
58✔
100
        d_rotatingTicketsKey.clear();
×
101
        throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
×
102
      }
×
103
      catch(...) {
58✔
104
        d_rotatingTicketsKey.clear();
×
105
        throw;
×
106
      }
×
107
    }
58✔
108
  }
333✔
109

110
  time_t getNextTicketsKeyRotation() const
111
  {
12✔
112
    return d_ticketsKeyNextRotation;
12✔
113
  }
12✔
114

115
  virtual size_t getTicketsKeysCount() = 0;
116
  virtual std::string getName() const = 0;
117

118
  using tickets_key_added_hook = std::function<void(const std::string& key)>;
119

120
  static void setTicketsKeyAddedHook(const tickets_key_added_hook& hook)
121
  {
4✔
122
    TLSCtx::s_ticketsKeyAddedHook = hook;
4✔
123
  }
4✔
124
  static const tickets_key_added_hook& getTicketsKeyAddedHook()
125
  {
8✔
126
    return TLSCtx::s_ticketsKeyAddedHook;
8✔
127
  }
8✔
128
  static bool hasTicketsKeyAddedHook()
129
  {
466✔
130
    return TLSCtx::s_ticketsKeyAddedHook != nullptr;
466✔
131
  }
466✔
132
protected:
133
  std::atomic_flag d_rotatingTicketsKey;
134
  std::atomic<time_t> d_ticketsKeyNextRotation{0};
135
  time_t d_ticketsKeyRotationDelay{0};
136

137
private:
138
  static tickets_key_added_hook s_ticketsKeyAddedHook;
139
};
140

141
class TLSFrontend
142
{
143
public:
144
  enum class ALPN : uint8_t { Unset, DoT, DoH };
145

146
  TLSFrontend(ALPN alpn): d_alpn(alpn)
147
  {
196✔
148
  }
196✔
149

150
  TLSFrontend(std::shared_ptr<TLSCtx> ctx): d_ctx(std::move(ctx))
151
  {
71✔
152
  }
71✔
153

154
  bool setupTLS();
155

156
  void rotateTicketsKey(time_t now)
157
  {
22✔
158
    if (d_ctx != nullptr) {
22!
159
      d_ctx->rotateTicketsKey(now);
22✔
160
    }
22✔
161
  }
22✔
162

163
  void loadTicketsKeys(const std::string& file)
164
  {
6✔
165
    if (d_ctx != nullptr) {
6!
166
      d_ctx->loadTicketsKeys(file);
6✔
167
    }
6✔
168
  }
6✔
169

170
  void loadTicketsKey(const std::string& key)
171
  {
×
172
    if (d_ctx != nullptr) {
×
173
      d_ctx->loadTicketsKey(key);
×
174
    }
×
175
  }
×
176

177
  std::shared_ptr<TLSCtx> getContext()
178
  {
1,172✔
179
    return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
1,172✔
180
  }
1,172✔
181

182
  void cleanup()
183
  {
×
184
    d_ctx.reset();
×
185
  }
×
186

187
  size_t getTicketsKeysCount()
188
  {
12✔
189
    if (d_ctx != nullptr) {
12!
190
      return d_ctx->getTicketsKeysCount();
12✔
191
    }
12✔
192

193
    return 0;
×
194
  }
12✔
195

196
  static std::string timeToString(time_t rotationTime)
197
  {
12✔
198
    char buf[20];
12✔
199
    struct tm date_tm;
12✔
200

201
    localtime_r(&rotationTime, &date_tm);
12✔
202
    strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
12✔
203

204
    return std::string(buf);
12✔
205
  }
12✔
206

207
  time_t getTicketsKeyRotationDelay() const
208
  {
×
209
    return d_tlsConfig.d_ticketsKeyRotationDelay;
×
210
  }
×
211

212
  std::string getNextTicketsKeyRotation() const
213
  {
12✔
214
    std::string res;
12✔
215

216
    if (d_ctx != nullptr) {
12!
217
      res = timeToString(d_ctx->getNextTicketsKeyRotation());
12✔
218
    }
12✔
219

220
    return res;
12✔
221
  }
12✔
222

223
  std::string getRequestedProvider() const
224
  {
×
225
    return d_provider;
×
226
  }
×
227

228
  std::string getEffectiveProvider() const
229
  {
6✔
230
    if (d_ctx) {
6!
231
      return d_ctx->getName();
6✔
232
    }
6✔
233
    return "";
×
234
  }
6✔
235

236
  TLSConfig d_tlsConfig;
237
  TLSErrorCounters d_tlsCounters;
238
  ComboAddress d_addr;
239
  std::string d_provider;
240
  ALPN d_alpn{ALPN::Unset};
241
  /* whether the proxy protocol is inside or outside the TLS layer */
242
  bool d_proxyProtocolOutsideTLS{false};
243
protected:
244
  std::shared_ptr<TLSCtx> d_ctx{nullptr};
245
};
246

247
class TCPIOHandler
248
{
249
public:
250
  TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx) :
251
    d_socket(socket)
252
  {
6,273✔
253
    if (ctx) {
6,273✔
254
      d_conn = ctx->getClientConnection(host, hostIsAddr, d_socket, timeout);
519✔
255
    }
519✔
256
  }
6,273✔
257

258
  TCPIOHandler(int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx, time_t now) :
259
    d_socket(socket)
260
  {
4,621✔
261
    if (ctx) {
4,621✔
262
      d_conn = ctx->getConnection(d_socket, timeout, now);
801✔
263
    }
801✔
264
  }
4,621✔
265

266
  ~TCPIOHandler()
267
  {
22,263✔
268
    close();
22,263✔
269
  }
22,263✔
270

271
  void close()
272
  {
55,095✔
273
    if (d_conn) {
55,095✔
274
      d_conn->close();
3,242✔
275
      d_conn.reset();
3,242✔
276
    }
3,242✔
277

278
    if (d_socket != -1) {
55,095✔
279
      shutdown(d_socket, SHUT_RDWR);
21,920✔
280
      ::close(d_socket);
21,920✔
281
      d_socket = -1;
21,920✔
282
    }
21,920✔
283
  }
55,095✔
284

285
  int getDescriptor() const
286
  {
30,092✔
287
    return d_socket;
30,092✔
288
  }
30,092✔
289

290
  IOState tryConnect(bool fastOpen, const ComboAddress& remote)
291
  {
4,113✔
292
    d_remote = remote;
4,113✔
293

294
#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
4,113✔
295
    if (fastOpen) {
4,113✔
296
      int value = 1;
10✔
297
      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
10✔
298
      if (res == 0) {
10!
299
        fastOpen = false;
10✔
300
      }
10✔
301
    }
10✔
302
#endif /* TCP_FASTOPEN_CONNECT */
4,113✔
303

304
#ifdef MSG_FASTOPEN
4,113✔
305
    if (!d_conn && fastOpen) {
4,113!
306
      d_fastOpen = true;
×
307
    }
×
308
    else {
4,113✔
309
      if (!s_disableConnectForUnitTests) {
4,113✔
310
        SConnectWithTimeout(d_socket, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
3,958✔
311
      }
3,958✔
312
    }
4,113✔
313
#else
314
    if (!s_disableConnectForUnitTests) {
315
      SConnectWithTimeout(d_socket, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
316
    }
317
#endif /* MSG_FASTOPEN */
318

319
    if (d_conn) {
4,113✔
320
      return d_conn->tryConnect(fastOpen, remote);
385✔
321
    }
385✔
322

323
    return IOState::Done;
3,728✔
324
  }
4,113✔
325

326
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout)
327
  {
7✔
328
    d_remote = remote;
7✔
329

330
#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
7✔
331
    if (fastOpen) {
7!
332
      int value = 1;
×
333
      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
×
334
      if (res == 0) {
×
335
        fastOpen = false;
×
336
      }
×
337
    }
×
338
#endif /* TCP_FASTOPEN_CONNECT */
7✔
339

340
#ifdef MSG_FASTOPEN
7✔
341
    if (!d_conn && fastOpen) {
7!
342
      d_fastOpen = true;
×
343
    }
×
344
    else {
7✔
345
      if (!s_disableConnectForUnitTests) {
7!
346
        SConnectWithTimeout(d_socket, remote, timeout);
7✔
347
      }
7✔
348
    }
7✔
349
#else
350
    if (!s_disableConnectForUnitTests) {
351
      SConnectWithTimeout(d_socket, remote, timeout);
352
    }
353
#endif /* MSG_FASTOPEN */
354

355
    if (d_conn) {
7✔
356
      d_conn->connect(fastOpen, remote, timeout);
4✔
357
    }
4✔
358
  }
7✔
359

360
  IOState tryHandshake()
361
  {
5,074✔
362
    if (d_conn) {
5,074✔
363
      return d_conn->tryHandshake();
1,254✔
364
    }
1,254✔
365
    return IOState::Done;
3,820✔
366
  }
5,074✔
367

368
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0}, bool allowIncomplete=false)
369
  {
×
370
    if (d_conn) {
×
371
      return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
×
372
    } else {
×
373
      return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
×
374
    }
×
375
  }
×
376

377
  /* Tries to read exactly toRead - pos bytes into the buffer, starting at position pos.
378
     Updates pos everytime a successful read occurs,
379
     throws an std::runtime_error in case of IO error,
380
     return Done when toRead bytes have been read, needRead or needWrite if the IO operation
381
     would block.
382
  */
383
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false, bool bypassFilters=false)
384
  {
218,746✔
385
    if (buffer.size() < toRead || pos >= toRead) {
218,751!
386
      throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
×
387
    }
×
388

389
    if (!bypassFilters && d_conn) {
218,746✔
390
      return d_conn->tryRead(buffer, pos, toRead, allowIncomplete);
137,948✔
391
    }
137,948✔
392

393
    do {
81,192✔
394
      ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - pos);
81,192✔
395
      if (res == 0) {
81,192✔
396
        throw runtime_error("EOF while reading message");
14,477✔
397
      }
14,477✔
398
      if (res < 0) {
66,715✔
399
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
21,482!
400
          return IOState::NeedRead;
21,357✔
401
        }
21,357✔
402
        else {
125✔
403
          throw std::runtime_error("Error while reading message: " + stringerror());
125✔
404
        }
125✔
405
      }
21,482✔
406

407
      pos += static_cast<size_t>(res);
45,233✔
408
      if (allowIncomplete) {
45,233✔
409
        break;
20✔
410
      }
20✔
411
    }
45,233✔
412
    while (pos < toRead);
80,798✔
413

414
    return IOState::Done;
44,839✔
415
  }
80,798✔
416

417
  /* Tries to write exactly toWrite - pos bytes from the buffer, starting at position pos.
418
     Updates pos everytime a successful write occurs,
419
     throws an std::runtime_error in case of IO error,
420
     return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
421
     would block.
422
  */
423
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite)
424
  {
95,831✔
425
    if (buffer.size() < toWrite || pos >= toWrite) {
95,831!
426
      throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));
×
427
    }
×
428
    if (d_conn) {
95,831✔
429
      return d_conn->tryWrite(buffer, pos, toWrite);
69,140✔
430
    }
69,140✔
431

432
#ifdef MSG_FASTOPEN
26,691✔
433
    if (d_fastOpen) {
26,691!
434
      int socketFlags = MSG_FASTOPEN;
×
435
      size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos, &d_remote, nullptr, 0, socketFlags);
×
436
      if (sent > 0) {
×
437
        d_fastOpen = false;
×
438
        pos += sent;
×
439
      }
×
440

441
      if (pos < toWrite) {
×
442
        return IOState::NeedWrite;
×
443
      }
×
444

445
      return IOState::Done;
×
446
    }
×
447
#endif /* MSG_FASTOPEN */
26,691✔
448

449
    do {
26,706✔
450
      ssize_t res = ::write(d_socket, reinterpret_cast<const char*>(&buffer.at(pos)), toWrite - pos);
26,706✔
451

452
      if (res == 0) {
26,706!
453
        throw runtime_error("EOF while sending message");
×
454
      }
×
455
      if (res < 0) {
26,706✔
456
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
15!
457
          return IOState::NeedWrite;
15✔
458
        }
15✔
459
        else {
×
460
          throw std::runtime_error("Error while writing message: " + stringerror());
×
461
        }
×
462
      }
15✔
463

464
      pos += static_cast<size_t>(res);
26,691✔
465
    }
26,691✔
466
    while (pos < toWrite);
26,691✔
467

468
    return IOState::Done;
26,676✔
469
  }
26,691✔
470

471
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout)
472
  {
×
473
    if (d_conn) {
×
474
      return d_conn->write(buffer, bufferSize, writeTimeout);
×
475
    }
×
476

×
477
#ifdef MSG_FASTOPEN
×
478
    if (d_fastOpen) {
×
479
      int socketFlags = MSG_FASTOPEN;
×
480
      size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(buffer), bufferSize, &d_remote, nullptr, 0, socketFlags);
×
481
      if (sent > 0) {
×
482
        d_fastOpen = false;
×
483
      }
×
484

×
485
      return sent;
×
486
    }
×
487
#endif /* MSG_FASTOPEN */
×
488

×
489
    return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
×
490
  }
×
491

492
  std::string getServerNameIndication() const
493
  {
22,875✔
494
    if (d_conn) {
22,875✔
495
      return d_conn->getServerNameIndication();
20,893✔
496
    }
20,893✔
497
    return std::string();
1,982✔
498
  }
22,875✔
499

500
  std::vector<uint8_t> getNextProtocol() const
501
  {
159✔
502
    if (d_conn) {
159!
503
      return d_conn->getNextProtocol();
159✔
504
    }
159✔
505
    return std::vector<uint8_t>();
×
506
  }
159✔
507

508
  LibsslTLSVersion getTLSVersion() const
509
  {
20,896✔
510
    if (d_conn) {
20,896!
511
      return d_conn->getTLSVersion();
20,896✔
512
    }
20,896✔
513
    return LibsslTLSVersion::Unknown;
×
514
  }
20,896✔
515

516
  bool isTLS() const
517
  {
138,925✔
518
    return d_conn != nullptr;
138,925✔
519
  }
138,925✔
520

521
  bool hasTLSSessionBeenResumed() const
522
  {
1,126✔
523
    return d_conn && d_conn->hasSessionBeenResumed();
1,126!
524
  }
1,126✔
525

526
  bool getResumedFromInactiveTicketKey() const
527
  {
372✔
528
    return d_conn && d_conn->getResumedFromInactiveTicketKey();
372!
529
  }
372✔
530

531
  bool getUnknownTicketKey() const
532
  {
372✔
533
    return d_conn && d_conn->getUnknownTicketKey();
372!
534
  }
372✔
535

536
  void setTLSSession(std::unique_ptr<TLSSession>& session)
537
  {
146✔
538
    if (d_conn != nullptr) {
146!
539
      d_conn->setSession(session);
146✔
540
    }
146✔
541
  }
146✔
542

543
  std::vector<std::unique_ptr<TLSSession>> getTLSSessions()
544
  {
293✔
545
    if (!d_conn) {
293!
546
      throw std::runtime_error("Trying to get TLS sessions from a non-TLS handler");
×
547
    }
×
548

549
    return d_conn->getSessions();
293✔
550
  }
293✔
551

552
  bool isUsable() const
553
  {
192✔
554
    if (!d_conn) {
192✔
555
      return isTCPSocketUsable(d_socket);
184✔
556
    }
184✔
557
    return d_conn->isUsable();
8✔
558
  }
192✔
559

560
  std::vector<int> getAsyncFDs()
561
  {
2,080✔
562
    if (!d_conn) {
2,080✔
563
      return {};
1,889✔
564
    }
1,889✔
565
    return d_conn->getAsyncFDs();
191✔
566
  }
2,080✔
567

568
  static const bool s_disableConnectForUnitTests;
569

570
private:
571
  std::unique_ptr<TLSConnection> d_conn{nullptr};
572
  ComboAddress d_remote;
573
  int d_socket{-1};
574
#ifdef MSG_FASTOPEN
575
  bool d_fastOpen{false};
576
#endif
577
};
578

579
struct TLSContextParameters
580
{
581
  std::string d_provider;
582
  std::string d_ciphers;
583
  std::string d_ciphers13;
584
  std::string d_caStore;
585
  std::string d_keyLogFile;
586
  TLSFrontend::ALPN d_alpn{TLSFrontend::ALPN::Unset};
587
  bool d_validateCertificates{true};
588
  bool d_releaseBuffers{true};
589
  bool d_enableRenegotiation{false};
590
  bool d_ktls{false};
591
};
592

593
/* Builds a backend-specific TLS context from params (defined elsewhere).
   NOTE(review): failure behavior (null result vs. exception) is not visible here — confirm. */
std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);

/* Protocol-negotiation setup for DNS-over-TLS and DNS-over-HTTPS on an existing context
   (defined elsewhere). NOTE(review): presumably configures ALPN; the meaning of the
   bool result is not visible here — confirm. */
bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);
bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc