• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 12323094430

13 Dec 2024 09:11PM UTC coverage: 64.759% (-0.02%) from 64.78%
12323094430

Pull #14970

github

web-flow
Merge 3e4597ff7 into 3dfd8e317
Pull Request #14970: boost > std optional

37533 of 88820 branches covered (42.26%)

Branch coverage included in aggregate %.

17 of 19 new or added lines in 4 files covered. (89.47%)

79 existing lines in 16 files now uncovered.

125890 of 163537 relevant lines covered (76.98%)

4110788.26 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

68.18
/pdns/tcpiohandler.cc
1

2
#include "config.h"
3
#include "dolog.hh"
4
#include "iputils.hh"
5
#include "lock.hh"
6
#include "tcpiohandler.hh"
7

8
// Set to false in this translation unit. Presumably a unit-test build supplies a different
// definition so that connect() is not actually performed; confirm against the test sources.
const bool TCPIOHandler::s_disableConnectForUnitTests = false;
9

10
namespace {
/* Decides whether extra diagnostics (such as dumping OpenSSL's error queue)
   should be emitted:
   - dnsdist follows the "verbose" flag of its current runtime configuration;
   - the recursor never does;
   - any other build always does. */
bool shouldDoVerboseLogging()
{
#if defined(DNSDIST)
  const auto& runtimeConfig = dnsdist::configuration::getCurrentRuntimeConfiguration();
  return runtimeConfig.d_verbose;
#elif defined(RECURSOR)
  return false;
#else
  return true;
#endif
}
} // namespace
22

23
#ifdef HAVE_LIBSODIUM
24
#include <sodium.h>
25
#endif /* HAVE_LIBSODIUM */
26

27
// Process-wide hook, empty (nullptr) by default. Going by its name it is presumably
// called whenever a new session tickets key is added; confirm in TLSCtx.
TLSCtx::tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr};
28

29
static std::vector<std::vector<uint8_t>> getALPNVector(TLSFrontend::ALPN alpn, bool client)
30
{
105✔
31
  if (alpn == TLSFrontend::ALPN::DoT) {
105✔
32
    /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */
33
    return std::vector<std::vector<uint8_t>>{{'d', 'o', 't'}};
42✔
34
  }
42✔
35
  if (alpn == TLSFrontend::ALPN::DoH) {
63✔
36
    if (client) {
51✔
37
      /* we want to set the ALPN to h2, if only to mitigate the ALPACA attack */
38
      return std::vector<std::vector<uint8_t>>{{'h', '2'}};
22✔
39
    }
22✔
40
    /* For server contexts, we want to set the ALPN for DoH (note that h2o sets it own ALPN values):
41
       - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response
42
       - HTTP/2
43
    */
44
    return std::vector<std::vector<uint8_t>>{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}};
29✔
45
  }
51✔
46
  return {};
12✔
47
}
63✔
48

49
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
50
#ifdef HAVE_LIBSSL
51

52
#include <openssl/conf.h>
53
#include <openssl/err.h>
54
#include <openssl/rand.h>
55
#include <openssl/ssl.h>
56
#include <openssl/x509v3.h>
57

58
#include "libssl.hh"
59

60

61
class OpenSSLFrontendContext
62
{
63
public:
64
  OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
65
  {
51✔
66
    registerOpenSSLUser();
51✔
67

68
    auto [ctx, warnings] = libssl_init_server_context(tlsConfig, d_ocspResponses);
51✔
69
    for (const auto& warning : warnings) {
51✔
70
      warnlog("%s", warning);
2✔
71
    }
2✔
72
    d_tlsCtx = std::move(ctx);
51✔
73

74
    if (!d_tlsCtx) {
51!
75
      ERR_print_errors_fp(stderr);
×
76
      throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
×
77
    }
×
78
  }
51✔
79

80
  void cleanup()
81
  {
×
82
    d_tlsCtx.reset();
×
83

×
84
    unregisterOpenSSLUser();
×
85
  }
×
86

87
  OpenSSLTLSTicketKeysRing d_ticketKeys;
88
  std::map<int, std::string> d_ocspResponses;
89
  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
90
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
91
};
92

93
class OpenSSLSession : public TLSSession
94
{
95
public:
96
  OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>&& sess): d_sess(std::move(sess))
97
  {
90✔
98
  }
90✔
99

100
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
101
  {
42✔
102
    return std::move(d_sess);
42✔
103
  }
42✔
104

105
private:
106
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
107
};
108

109
class OpenSSLTLSIOCtx;
110

111
class OpenSSLTLSConnection: public TLSConnection
112
{
113
public:
114
  /* server side connection */
115
  OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout)
116
  {
259✔
117
    d_socket = socket;
259✔
118

119
    if (!d_conn) {
259!
120
      vinfolog("Error creating TLS object");
×
121
      if (shouldDoVerboseLogging()) {
×
122
        ERR_print_errors_fp(stderr);
×
123
      }
×
124
      throw std::runtime_error("Error creating TLS object");
×
125
    }
×
126

127
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
259!
128
      throw std::runtime_error("Error assigning socket");
×
129
    }
×
130

131
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
259✔
132
  }
259✔
133

134
  /* client-side connection */
135
  OpenSSLTLSConnection(std::string hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true)
136
  {
79✔
137
    d_socket = socket;
79✔
138

139
    if (!d_conn) {
79!
140
      vinfolog("Error creating TLS object");
×
141
      if (shouldDoVerboseLogging()) {
×
142
        ERR_print_errors_fp(stderr);
×
143
      }
×
144
      throw std::runtime_error("Error creating TLS object");
×
145
    }
×
146

147
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
79!
148
      throw std::runtime_error("Error assigning socket");
×
149
    }
×
150

151
    /* set outgoing Server Name Indication */
152
    if (!d_hostname.empty() && SSL_set_tlsext_host_name(d_conn.get(), d_hostname.c_str()) != 1) {
79!
153
      throw std::runtime_error("Error setting TLS SNI to " + d_hostname);
×
154
    }
×
155

156
    if (hostIsAddr) {
79✔
157
#if (OPENSSL_VERSION_NUMBER >= 0x10002000L)
2✔
158
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
2✔
159
      /* Enable automatic IP checks */
160
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
2✔
161
      if (X509_VERIFY_PARAM_set1_ip_asc(param, d_hostname.c_str()) != 1) {
2!
162
        throw std::runtime_error("Error setting TLS IP for certificate validation");
×
163
      }
×
164
#else
165
      /* no validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
166
#endif
167
    }
2✔
168
    else {
77✔
169
#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl
77✔
170
      SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
77✔
171
      if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
77!
172
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
×
173
      }
×
174
#elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
175
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
176
      /* Enable automatic hostname checks */
177
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
178
      if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
179
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
180
      }
181
#else
182
      /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
183
#endif
184
    }
77✔
185

186
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
79✔
187
  }
79✔
188

189
  std::vector<int> getAsyncFDs() override
190
  {
107✔
191
    std::vector<int> results;
107✔
192
#ifdef SSL_MODE_ASYNC
107✔
193
    if (SSL_waiting_for_async(d_conn.get()) != 1) {
107!
194
      return results;
107✔
195
    }
107✔
196

197
    OSSL_ASYNC_FD fds[32];
×
198
    size_t numfds = sizeof(fds)/sizeof(*fds);
×
199
    SSL_get_all_async_fds(d_conn.get(), nullptr, &numfds);
×
200
    if (numfds == 0) {
×
201
      return results;
×
202
    }
×
203

204
    SSL_get_all_async_fds(d_conn.get(), fds, &numfds);
×
205
    results.reserve(numfds);
×
206
    for (size_t idx = 0; idx < numfds; idx++) {
×
207
      results.push_back(fds[idx]);
×
208
    }
×
209
#endif
×
210
    return results;
×
211
  }
×
212

213
  IOState convertIORequestToIOState(int res) const
214
  {
1,108✔
215
    int error = SSL_get_error(d_conn.get(), res);
1,108✔
216
    if (error == SSL_ERROR_WANT_READ) {
1,108✔
217
      return IOState::NeedRead;
841✔
218
    }
841✔
219
    else if (error == SSL_ERROR_WANT_WRITE) {
267✔
220
      return IOState::NeedWrite;
12✔
221
    }
12✔
222
    else if (error == SSL_ERROR_SYSCALL) {
255✔
223
      if (errno == 0) {
2!
224
        throw std::runtime_error("TLS connection closed by remote end");
×
225
      }
×
226
      else {
2✔
227
        throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
2✔
228
      }
2✔
229
    }
2✔
230
    else if (error == SSL_ERROR_ZERO_RETURN) {
253✔
231
      throw std::runtime_error("TLS connection closed by remote end");
246✔
232
    }
246✔
233
#ifdef SSL_MODE_ASYNC
7✔
234
    else if (error == SSL_ERROR_WANT_ASYNC) {
7!
235
      return IOState::Async;
×
236
    }
×
237
#endif
7✔
238
    else {
7✔
239
      if (shouldDoVerboseLogging()) {
7✔
240
        throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error) + ") " + libssl_get_error_string());
1✔
241
      } else {
6✔
242
        throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
6✔
243
      }
6✔
244
    }
7✔
245
  }
1,108✔
246

247
  void handleIORequest(int res, const struct timeval& timeout)
248
  {
5✔
249
    auto state = convertIORequestToIOState(res);
5✔
250
    if (state == IOState::NeedRead) {
5✔
251
      res = waitForData(d_socket, timeout.tv_sec, timeout.tv_usec);
4✔
252
      if (res == 0) {
4!
253
        throw std::runtime_error("Timeout while reading from TLS connection");
×
254
      }
×
255
      else if (res < 0) {
4!
256
        throw std::runtime_error("Error waiting to read from TLS connection");
×
257
      }
×
258
    }
4✔
259
    else if (state == IOState::NeedWrite) {
1!
260
      res = waitForRWData(d_socket, false, timeout.tv_sec, timeout.tv_usec);
×
261
      if (res == 0) {
×
262
        throw std::runtime_error("Timeout while writing to TLS connection");
×
263
      }
×
264
      else if (res < 0) {
×
265
        throw std::runtime_error("Error waiting to write to TLS connection");
×
266
      }
×
267
    }
×
268
  }
5✔
269

270
  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
271
  {
72✔
272
    /* sorry */
273
    (void) fastOpen;
72✔
274
    (void) remote;
72✔
275

276
    int res = SSL_connect(d_conn.get());
72✔
277
    if (res == 1) {
72!
UNCOV
278
      return IOState::Done;
×
UNCOV
279
    }
×
280
    else if (res < 0) {
72!
281
      return convertIORequestToIOState(res);
72✔
282
    }
72✔
283

284
    throw std::runtime_error("Error establishing a TLS connection");
×
285
  }
72✔
286

287
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval &timeout) override
288
  {
4✔
289
    /* sorry */
290
    (void) fastOpen;
4✔
291
    (void) remote;
4✔
292

293
    struct timeval start{0,0};
4✔
294
    struct timeval remainingTime = timeout;
4✔
295
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
4!
296
      gettimeofday(&start, nullptr);
4✔
297
    }
4✔
298

299
    int res = 0;
4✔
300
    do {
8✔
301
      res = SSL_connect(d_conn.get());
8✔
302
      if (res < 0) {
8✔
303
        handleIORequest(res, remainingTime);
5✔
304
      }
5✔
305

306
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
8!
307
        struct timeval now;
7✔
308
        gettimeofday(&now, nullptr);
7✔
309
        struct timeval elapsed = now - start;
7✔
310
        if (now < start || remainingTime < elapsed) {
7!
311
          throw runtime_error("Timeout while establishing TLS connection");
×
312
        }
×
313
        start = now;
7✔
314
        remainingTime = remainingTime - elapsed;
7✔
315
      }
7✔
316
    }
8✔
317
    while (res != 1);
8✔
318
  }
4✔
319

320
  IOState tryHandshake() override
321
  {
516✔
322
    if (isClient()) {
516!
323
      /* In client mode, the handshake is initiated by the call to SSL_connect()
324
         done from connect()/tryConnect().
325
         In blocking mode it does not return before the handshake has been finished,
326
         and in non-blocking mode calling SSL_connect() once is enough for SSL_write()
327
         and SSL_read() to transparently continue to negotiate the connection after that
328
         (equivalent to doing SSL_set_connect_state() plus trying to write).
329
      */
330
      return IOState::Done;
×
331
    }
×
332

333
    /* As explained above in the client-mode block, we only need to call SSL_accept() once
334
       for SSL_write() and SSL_read() to transparently continue to negotiate the connection after that.
335
       It is equivalent to calling SSL_set_accept_state() plus trying to read.
336
    */
337
    int res = SSL_accept(d_conn.get());
516✔
338
    if (res == 1) {
516✔
339
      return IOState::Done;
257✔
340
    }
257✔
341
    else if (res < 0) {
259!
342
      return convertIORequestToIOState(res);
259✔
343
    }
259✔
344

345
    throw std::runtime_error("Error accepting TLS connection");
×
346
  }
516✔
347

348
  void doHandshake() override
349
  {
×
350
    if (isClient()) {
×
351
      /* we are a client, nothing to do, see the non-blocking version */
352
      return;
×
353
    }
×
354

355
    int res = 0;
×
356
    do {
×
357
      res = SSL_accept(d_conn.get());
×
358
      if (res < 0) {
×
359
        handleIORequest(res, d_timeout);
×
360
      }
×
361
    }
×
362
    while (res < 0);
×
363

364
    if (res != 1) {
×
365
      throw std::runtime_error("Error accepting TLS connection");
×
366
    }
×
367
  }
×
368

369
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
370
  {
1,043✔
371
    if (isClient() && !d_connected) {
1,043✔
372
      if (d_ktls) {
149!
373
        /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
374
        SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
×
375
      }
×
376
    }
149✔
377

378
    do {
1,043✔
379
      int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
1,043✔
380
      if (res <= 0) {
1,043✔
381
        return convertIORequestToIOState(res);
87✔
382
      }
87✔
383
      else {
956✔
384
        pos += static_cast<size_t>(res);
956✔
385
      }
956✔
386
    }
1,043✔
387
    while (pos < toWrite);
1,043!
388

389
    if (!d_connected) {
956✔
390
      d_connected = true;
290✔
391
    }
290✔
392

393
    return IOState::Done;
956✔
394
  }
1,043✔
395

396
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
397
  {
2,153✔
398
    do {
2,155✔
399
      int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
2,155✔
400
      if (res <= 0) {
2,155✔
401
        return convertIORequestToIOState(res);
683✔
402
      }
683✔
403
      else {
1,472✔
404
        pos += static_cast<size_t>(res);
1,472✔
405
        if (allowIncomplete) {
1,472✔
406
          break;
754✔
407
        }
754✔
408
      }
1,472✔
409
    }
2,155✔
410
    while (pos < toRead);
2,153✔
411
    return IOState::Done;
1,470✔
412
  }
2,153✔
413

414
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
415
  {
×
416
    size_t got = 0;
×
417
    struct timeval start = {0, 0};
×
418
    struct timeval remainingTime = totalTimeout;
×
419
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
420
      gettimeofday(&start, nullptr);
×
421
    }
×
422

423
    do {
×
424
      int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
425
      if (res <= 0) {
×
426
        handleIORequest(res, readTimeout);
×
427
      }
×
428
      else {
×
429
        got += static_cast<size_t>(res);
×
430
        if (allowIncomplete) {
×
431
          break;
×
432
        }
×
433
      }
×
434

435
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
436
        struct timeval now;
×
437
        gettimeofday(&now, nullptr);
×
438
        struct timeval elapsed = now - start;
×
439
        if (now < start || remainingTime < elapsed) {
×
440
          throw runtime_error("Timeout while reading data");
×
441
        }
×
442
        start = now;
×
443
        remainingTime = remainingTime - elapsed;
×
444
      }
×
445
    }
×
446
    while (got < bufferSize);
×
447

448
    return got;
×
449
  }
×
450

451
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
452
  {
×
453
    size_t got = 0;
×
454
    do {
×
455
      int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
456
      if (res <= 0) {
×
457
        handleIORequest(res, writeTimeout);
×
458
      }
×
459
      else {
×
460
        got += static_cast<size_t>(res);
×
461
      }
×
462
    }
×
463
    while (got < bufferSize);
×
464

465
    return got;
×
466
  }
×
467

468
  bool isUsable() const override
469
  {
2✔
470
    if (!d_conn) {
2!
471
      return false;
×
472
    }
×
473

474
    char buf;
2✔
475
    int res = SSL_peek(d_conn.get(), &buf, sizeof(buf));
2✔
476
    if (res > 0) {
2!
477
      return true;
×
478
    }
×
479
    try {
2✔
480
      convertIORequestToIOState(res);
2✔
481
      return true;
2✔
482
    }
2✔
483
    catch (...) {
2✔
484
      return false;
×
485
    }
×
486

487
    return false;
×
488
  }
2✔
489

490
  void close() override
491
  {
311✔
492
    if (d_conn) {
311!
493
      SSL_shutdown(d_conn.get());
311✔
494
    }
311✔
495
  }
311✔
496

497
  std::string getServerNameIndication() const override
498
  {
414✔
499
    if (d_conn) {
414!
500
      const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
414✔
501
      if (value) {
414!
502
        return std::string(value);
414✔
503
      }
414✔
504
    }
414✔
505
    return std::string();
×
506
  }
414✔
507

508
  std::vector<uint8_t> getNextProtocol() const override
509
  {
150✔
510
    std::vector<uint8_t> result;
150✔
511
    if (!d_conn) {
150!
512
      return result;
×
513
    }
×
514

515
    const unsigned char* alpn = nullptr;
150✔
516
    unsigned int alpnLen  = 0;
150✔
517
#ifdef HAVE_SSL_GET0_ALPN_SELECTED
150✔
518
    if (alpn == nullptr) {
150!
519
      SSL_get0_alpn_selected(d_conn.get(), &alpn, &alpnLen);
150✔
520
    }
150✔
521
#endif /* HAVE_SSL_GET0_ALPN_SELECTED */
150✔
522
    if (alpn != nullptr && alpnLen > 0) {
150!
523
      result.insert(result.end(), alpn, alpn + alpnLen);
128✔
524
    }
128✔
525
    return result;
150✔
526
  }
150✔
527

528
  LibsslTLSVersion getTLSVersion() const override
529
  {
417✔
530
    auto proto = SSL_version(d_conn.get());
417✔
531
    switch (proto) {
417✔
532
    case TLS1_VERSION:
×
533
      return LibsslTLSVersion::TLS10;
×
534
    case TLS1_1_VERSION:
×
535
      return LibsslTLSVersion::TLS11;
×
536
    case TLS1_2_VERSION:
6✔
537
      return LibsslTLSVersion::TLS12;
6✔
538
#ifdef TLS1_3_VERSION
×
539
    case TLS1_3_VERSION:
411✔
540
      return LibsslTLSVersion::TLS13;
411✔
541
#endif /* TLS1_3_VERSION */
×
542
    default:
×
543
      return LibsslTLSVersion::Unknown;
×
544
    }
417✔
545
  }
417✔
546

547
  bool hasSessionBeenResumed() const override
548
  {
289✔
549
    if (d_conn) {
289!
550
      return SSL_session_reused(d_conn.get()) != 0;
289✔
551
    }
289✔
552
    return false;
×
553
  }
289✔
554

555
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
556
  {
47✔
557
    return std::move(d_tlsSessions);
47✔
558
  }
47✔
559

560
  void setSession(std::unique_ptr<TLSSession>& session) override
561
  {
42✔
562
    auto sess = dynamic_cast<OpenSSLSession*>(session.get());
42✔
563
    if (!sess) {
42!
564
      throw std::runtime_error("Unable to convert OpenSSL session");
×
565
    }
×
566

567
    auto native = sess->getNative();
42✔
568
    auto ret = SSL_set_session(d_conn.get(), native.get());
42✔
569
    if (ret != 1) {
42!
570
      throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
×
571
    }
×
572
    session.reset();
42✔
573
  }
42✔
574

575
  void addNewTicket(SSL_SESSION* session)
576
  {
90✔
577
    d_tlsSessions.push_back(std::make_unique<OpenSSLSession>(std::unique_ptr<SSL_SESSION, void (*)(SSL_SESSION*)>(session, SSL_SESSION_free)));
90✔
578
  }
90✔
579

580
  void enableKTLS()
581
  {
×
582
    d_ktls = true;
×
583
  }
×
584

585
  [[nodiscard]] bool isClient() const
586
  {
1,559✔
587
    return d_isClient;
1,559✔
588
  }
1,559✔
589

590
  static void generateConnectionIndexIfNeeded()
591
  {
81✔
592
    auto init = s_initTLSConnIndex.lock();
81✔
593
    if (*init == true) {
81✔
594
      return;
25✔
595
    }
25✔
596

597
    /* not initialized yet */
598
    s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
56✔
599
    if (s_tlsConnIndex == -1) {
56!
600
      throw std::runtime_error("Error getting an index for TLS connection data");
×
601
    }
×
602

603
    *init = true;
56✔
604
  }
56✔
605

606
  static int getConnectionIndex()
607
  {
442✔
608
    return s_tlsConnIndex;
442✔
609
  }
442✔
610

611
private:
612
  static LockGuarded<bool> s_initTLSConnIndex;
613
  static int s_tlsConnIndex;
614
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
615
  const std::shared_ptr<const OpenSSLTLSIOCtx> d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime
616
  std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
617
  const std::string d_hostname;
618
  const timeval d_timeout;
619
  bool d_connected{false};
620
  bool d_ktls{false};
621
  const bool d_isClient{false};
622
};
623

624
LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
625
int OpenSSLTLSConnection::s_tlsConnIndex{-1};
626

627
class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this<OpenSSLTLSIOCtx>
628
{
629
  /* Private tag type: only members of OpenSSLTLSIOCtx can create one, so the
     public constructors below are only usable through the static factories. */
  struct Private { explicit Private() = default; };
633

634
public:
635
  /* Builds the server-side context for `frontend`. Instances are always owned by a
     shared_ptr, which the shared_from_this() calls made when creating connections rely on. */
  static std::shared_ptr<OpenSSLTLSIOCtx> createServerSideContext(TLSFrontend& frontend)
  {
    auto serverCtx = std::make_shared<OpenSSLTLSIOCtx>(frontend, Private());
    return serverCtx;
  }
639

640
  static std::shared_ptr<OpenSSLTLSIOCtx> createClientSideContext(const TLSContextParameters& params)
641
  {
30✔
642
    return std::make_shared<OpenSSLTLSIOCtx>(params, Private());
30✔
643
  }
30✔
644

645
  /* server side context */
646
  /* server side context: configures the frontend's SSL_CTX (ticket key callback,
     OCSP stapling, read-ahead, error counters, ALPN selection, key log file) and
     then either generates or loads the session ticket keys.
     If anything throws, the frontend context is cleaned up before rethrowing.
     The destructor does not run for a partially-constructed object, so the
     OpenSSL user registration taken by OpenSSLFrontendContext's constructor
     would otherwise never be released. */
  OpenSSLTLSIOCtx(TLSFrontend& frontend, [[maybe_unused]] Private priv): d_alpnProtos(getALPNVector(frontend.d_alpn, false)), d_feContext(std::make_unique<OpenSSLFrontendContext>(frontend.d_addr, frontend.d_tlsConfig))
  {
    try {
      OpenSSLTLSConnection::generateConnectionIndexIfNeeded();

      d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;

      if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) {
        /* use our own ticket keys handler so we can rotate them */
#if OPENSSL_VERSION_MAJOR >= 3
        SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
#else
        SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
#endif
        libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
      }

#ifndef DISABLE_OCSP_STAPLING
      if (!d_feContext->d_ocspResponses.empty()) {
        SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
        SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
      }
#endif /* DISABLE_OCSP_STAPLING */

      if (frontend.d_tlsConfig.d_readAhead) {
        SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1);
      }

      libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &frontend.d_tlsCounters);

      libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);

      if (!frontend.d_tlsConfig.d_keyLogFile.empty()) {
        d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx.get(), frontend.d_tlsConfig.d_keyLogFile);
      }

      if (frontend.d_tlsConfig.d_ticketKeyFile.empty()) {
        handleTicketsKeyRotation(time(nullptr));
      }
      else {
        OpenSSLTLSIOCtx::loadTicketsKeys(frontend.d_tlsConfig.d_ticketKeyFile);
      }
    }
    catch (...) {
      /* frees the SSL_CTX and balances the registerOpenSSLUser() call */
      d_feContext->cleanup();
      throw;
    }
  }
693

694
  /* client side context */
695
  OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private priv)
696
  {
30✔
697
    int sslOptions =
30✔
698
      SSL_OP_NO_SSLv2 |
30✔
699
      SSL_OP_NO_SSLv3 |
30✔
700
      SSL_OP_NO_COMPRESSION |
30✔
701
      SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
30✔
702
      SSL_OP_SINGLE_DH_USE |
30✔
703
      SSL_OP_SINGLE_ECDH_USE |
30✔
704
#ifdef SSL_OP_IGNORE_UNEXPECTED_EOF
30✔
705
      SSL_OP_IGNORE_UNEXPECTED_EOF |
30✔
706
#endif
30✔
707
      SSL_OP_CIPHER_SERVER_PREFERENCE;
30✔
708
    if (!params.d_enableRenegotiation) {
30!
709
#ifdef SSL_OP_NO_RENEGOTIATION
30✔
710
      sslOptions |= SSL_OP_NO_RENEGOTIATION;
30✔
711
#elif defined(SSL_OP_NO_CLIENT_RENEGOTIATION)
712
      sslOptions |= SSL_OP_NO_CLIENT_RENEGOTIATION;
713
#endif
714
    }
30✔
715

716
    if (params.d_ktls) {
30!
717
#ifdef SSL_OP_ENABLE_KTLS
×
718
      sslOptions |= SSL_OP_ENABLE_KTLS;
×
719
      d_ktls = true;
×
720
#endif /* SSL_OP_ENABLE_KTLS */
×
721
    }
×
722

723
    registerOpenSSLUser();
30✔
724

725
    OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
30✔
726

727
#ifdef HAVE_TLS_CLIENT_METHOD
728
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
729
#else
730
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
30✔
731
#endif
30✔
732
    if (!d_tlsCtx) {
30!
733
      ERR_print_errors_fp(stderr);
×
734
      throw std::runtime_error("Error creating TLS context");
×
735
    }
×
736

737
    SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
30✔
738
#if defined(SSL_CTX_set_ecdh_auto)
30✔
739
    SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
30✔
740
#endif
30✔
741

742
    if (!params.d_ciphers.empty()) {
30!
743
      if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
×
744
        ERR_print_errors_fp(stderr);
×
745
        throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
×
746
      }
×
747
    }
×
748
#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
30✔
749
    if (!params.d_ciphers13.empty()) {
30!
750
      if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
×
751
        ERR_print_errors_fp(stderr);
×
752
        throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
×
753
      }
×
754
    }
×
755
#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
30✔
756

757
    if (params.d_validateCertificates) {
30✔
758
      if (params.d_caStore.empty())  {
21!
759
        if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
×
760
          throw std::runtime_error("Error adding the system's default trusted CAs");
×
761
        }
×
762
      } else {
21✔
763
        if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
21!
764
          throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
×
765
        }
×
766
      }
21✔
767

768
      SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
21✔
769
#if (OPENSSL_VERSION_NUMBER < 0x10002000L)
770
      warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
771
#endif
772
    }
21✔
773

774
    /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
775
       but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
776
    SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE);
30✔
777
    SSL_CTX_sess_set_new_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb);
30✔
778

779
    if (!params.d_keyLogFile.empty()) {
30✔
780
      d_keyLogFile = libssl_set_key_log_file(d_tlsCtx.get(), params.d_keyLogFile);
4✔
781
    }
4✔
782

783
    libssl_set_alpn_protos(d_tlsCtx.get(), getALPNVector(params.d_alpn, true));
30✔
784

785
#ifdef SSL_MODE_RELEASE_BUFFERS
30✔
786
    if (params.d_releaseBuffers) {
30!
787
      SSL_CTX_set_mode(d_tlsCtx.get(), SSL_MODE_RELEASE_BUFFERS);
30✔
788
    }
30✔
789
#endif
30✔
790
  }
30✔
791

792
  OpenSSLTLSIOCtx(const OpenSSLTLSIOCtx&) = delete;
793
  OpenSSLTLSIOCtx(OpenSSLTLSIOCtx&&) = delete;
794
  OpenSSLTLSIOCtx& operator=(const OpenSSLTLSIOCtx&) = delete;
795
  OpenSSLTLSIOCtx& operator=(OpenSSLTLSIOCtx&&) = delete;
796

797
  /* Frees the SSL context(s), then drops the OpenSSL user registration held by
     this object (taken by OpenSSLFrontendContext's constructor for server
     contexts, and by the client-side constructor otherwise). */
  ~OpenSSLTLSIOCtx() override
  {
    d_tlsCtx.reset();
    if (d_feContext) {
      /* Otherwise the frontend's SSL_CTX would only be freed during member
         destruction, i.e. after unregisterOpenSSLUser() below. Free it first,
         in the same order as OpenSSLFrontendContext::cleanup(). */
      d_feContext->d_tlsCtx.reset();
    }
    unregisterOpenSSLUser();
  }
802

803
#if OPENSSL_VERSION_MAJOR >= 3
804
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx, int enc)
805
#else
806
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx, int enc)
807
#endif
808
  {
492✔
809
    auto* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
492✔
810
    if (ctx == nullptr) {
492!
811
      return -1;
×
812
    }
×
813

814
    int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
492✔
815
    if (enc == 0) {
492✔
816
      if (ret == 0 || ret == 2) {
30✔
817
        auto* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::getConnectionIndex()));
14✔
818
        if (conn != nullptr) {
14!
819
          if (ret == 0) {
14✔
820
            conn->setUnknownTicketKey();
6✔
821
          }
6✔
822
          else if (ret == 2) {
8!
823
            conn->setResumedFromInactiveTicketKey();
8✔
824
          }
8✔
825
        }
14✔
826
      }
14✔
827
    }
30✔
828

829
    return ret;
492✔
830
  }
492✔
831

832
#ifndef DISABLE_OCSP_STAPLING
833
  static int ocspStaplingCb(SSL* ssl, void* arg)
834
  {
4✔
835
    if (ssl == nullptr || arg == nullptr) {
4!
836
      return SSL_TLSEXT_ERR_NOACK;
×
837
    }
×
838
    const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
4✔
839
    return libssl_ocsp_stapling_callback(ssl, *ocspMap);
4✔
840
  }
4✔
841
#endif /* DISABLE_OCSP_STAPLING */
842

843
  /* New-session callback: hands the session received from the server over to the
     OpenSSLTLSConnection attached to `ssl`. Per SSL_CTX_sess_set_new_cb(3), returning 1
     signals that a reference to the session was kept, 0 lets OpenSSL release it. */
  static int newTicketFromServerCb(SSL* ssl, SSL_SESSION* session)
  {
    auto* connection = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(ssl, OpenSSLTLSConnection::getConnectionIndex()));
    if (session != nullptr && connection != nullptr) {
      connection->addNewTicket(session);
      return 1;
    }
    return 0;
  }
853

854
  /* Returns the SSL_CTX to create connections from: the frontend context's one when
     d_feContext is set (server side), our own d_tlsCtx otherwise (client side). */
  SSL_CTX* getOpenSSLContext() const
  {
    return d_feContext ? d_feContext->d_tlsCtx.get() : d_tlsCtx.get();
  }
861

862
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
863
  {
259✔
864
    handleTicketsKeyRotation(now);
259✔
865

866
    return std::make_unique<OpenSSLTLSConnection>(socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
259✔
867
  }
259✔
868

869
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
870
  {
79✔
871
    auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
79✔
872
    if (d_ktls) {
79!
873
      conn->enableKTLS();
×
874
    }
×
875
    return conn;
79✔
876
  }
79✔
877

878
  void rotateTicketsKey(time_t now) override
879
  {
95✔
880
    d_feContext->d_ticketKeys.rotateTicketsKey(now);
95✔
881

882
    if (d_ticketsKeyRotationDelay > 0) {
95!
883
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
95✔
884
    }
95✔
885
  }
95✔
886

887
  void loadTicketsKeys(const std::string& keyFile) final
888
  {
12✔
889
    d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
12✔
890

891
    if (d_ticketsKeyRotationDelay > 0) {
12!
892
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
12✔
893
    }
12✔
894
  }
12✔
895

896
  void loadTicketsKey(const std::string& key) final
897
  {
1✔
898
    d_feContext->d_ticketKeys.loadTicketsKey(key);
1✔
899

900
    if (d_ticketsKeyRotationDelay > 0) {
1!
901
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
1✔
902
    }
1✔
903
  }
1✔
904

905
  size_t getTicketsKeysCount() override
906
  {
6✔
907
    return d_feContext->d_ticketKeys.getKeysCount();
6✔
908
  }
6✔
909

910
  std::string getName() const override
911
  {
3✔
912
    return "openssl";
3✔
913
  }
3✔
914

915
  bool isServerContext() const
916
  {
×
917
    return d_feContext != nullptr;
×
918
  }
×
919

920
private:
921
  /* called in a client context, if the client advertised more than one ALPN value and the server returned more than one as well, to select the one to use. */
922
  static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
923
  {
130✔
924
    if (!arg) {
130!
925
      return SSL_TLSEXT_ERR_ALERT_WARNING;
×
926
    }
×
927
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
928
    OpenSSLTLSIOCtx* obj = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);
130✔
929

930
    const pdns::views::UnsignedCharView inView(in, inlen);
130✔
931
    // Server preference algorithm as per RFC 7301 section 3.2
932
    for (const auto& tentative : obj->d_alpnProtos) {
132✔
933
      size_t pos = 0;
132✔
934
      while (pos < inView.size()) {
137✔
935
        size_t protoLen = inView.at(pos);
133✔
936
        pos++;
133✔
937
        if (protoLen > (inlen - pos)) {
133!
938
          /* something is very wrong */
939
          return SSL_TLSEXT_ERR_ALERT_WARNING;
×
940
        }
×
941

942
        if (tentative.size() == protoLen && memcmp(&inView.at(pos), tentative.data(), tentative.size()) == 0) {
133!
943
          *out = &inView.at(pos);
128✔
944
          *outlen = protoLen;
128✔
945
          return SSL_TLSEXT_ERR_OK;
128✔
946
        }
128✔
947
        pos += protoLen;
5✔
948
      }
5✔
949
    }
132✔
950

951
    return SSL_TLSEXT_ERR_NOACK;
2✔
952
  }
130✔
953

954
  const std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
955
  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context; on the server side the context is stored in d_feContext->d_tlsCtx
956
  // non-null only for server contexts (see isServerContext()); holds the server SSL_CTX
  // (returned by getOpenSSLContext()) and the tickets keys used by the tickets-key methods
  std::unique_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
957
  // NOTE(review): presumably the TLS key log output file — confirm against the constructor
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
958
  // when true, getClientConnection() calls enableKTLS() on every new client connection
  bool d_ktls{false};
959
};
960

961
#endif /* HAVE_LIBSSL */
962

963
#ifdef HAVE_GNUTLS
964
#include <gnutls/gnutls.h>
965
#include <gnutls/x509.h>
966

967
/* Best-effort pinning of a memory region holding secret material, using libsodium's
   sodium_mlock() (documented to keep the pages out of swap). Without libsodium this is
   a no-op. The return value of sodium_mlock() is ignored. */
static void safe_memory_lock(void* data, size_t size)
{
#ifndef HAVE_LIBSODIUM
  (void)data;
  (void)size;
#else
  sodium_mlock(data, size);
#endif
}
973

974
/* Wipes a memory region that held secret material, using the first available primitive:
   sodium_munlock() (which also unlocks memory pinned by safe_memory_lock()),
   explicit_bzero(), explicit_memset(), gnutls_memset(). The final fallback is a
   memset() followed by a volatile read so that the compiler cannot drop the memset(). */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size == 0) {
    return;
  }

  volatile unsigned int zeroIndex = 0;
  const volatile unsigned char* bytes = reinterpret_cast<volatile unsigned char*>(data);
  for (;;) {
    memset(data, 0, size);
    if (bytes[zeroIndex] == 0) {
      break;
    }
  }
#endif
}
997

998
class GnuTLSTicketsKey
999
{
1000
public:
1001
  GnuTLSTicketsKey()
1002
  {
6✔
1003
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
6!
1004
      throw std::runtime_error("Error generating tickets key for TLS context");
1005
    }
1006

1007
    safe_memory_lock(d_key.data, d_key.size);
6✔
1008
  }
6✔
1009

1010
  GnuTLSTicketsKey(const std::string& key)
1011
  {
1✔
1012
    /* to be sure we are loading the correct amount of data, which
1013
       may change between versions, let's generate a correct key first */
1014
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
1!
1015
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
1016
    }
1017

1018
    safe_memory_lock(d_key.data, d_key.size);
1✔
1019
    if (key.size() != d_key.size) {
1!
1020
      safe_memory_release(d_key.data, d_key.size);
1021
      gnutls_free(d_key.data);
1022
      d_key.data = nullptr;
1023
      throw std::runtime_error("Invalid GnuTLS ticket key size");
1024
    }
1025
    memcpy(d_key.data, key.data(), key.size());
1✔
1026
  }
1✔
1027
  GnuTLSTicketsKey(std::ifstream& file)
1028
  {
1029
    /* to be sure we are loading the correct amount of data, which
1030
       may change between versions, let's generate a correct key first */
1031
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
×
1032
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
1033
    }
1034

1035
    safe_memory_lock(d_key.data, d_key.size);
1036

1037
    try {
1038
      file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
1039

1040
      if (file.fail()) {
×
1041
        throw std::runtime_error("Invalid GnuTLS tickets key file");
1042
      }
1043

1044
    }
1045
    catch (const std::exception& e) {
1046
      safe_memory_release(d_key.data, d_key.size);
1047
      gnutls_free(d_key.data);
1048
      d_key.data = nullptr;
1049
      throw;
1050
    }
1051
  }
1052
  [[nodiscard]] std::string content() const
1053
  {
2✔
1054
    std::string result{};
2✔
1055
    if (d_key.data != nullptr && d_key.size > 0) {
2!
1056
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1057
      result.append(reinterpret_cast<const char*>(d_key.data), d_key.size);
2✔
1058
      safe_memory_lock(result.data(), result.size());
2✔
1059
    }
2✔
1060
    return result;
2✔
1061
  }
2✔
1062

1063
  ~GnuTLSTicketsKey()
1064
  {
3✔
1065
    if (d_key.data != nullptr && d_key.size > 0) {
3!
1066
      safe_memory_release(d_key.data, d_key.size);
3✔
1067
    }
3✔
1068
    gnutls_free(d_key.data);
3✔
1069
    d_key.data = nullptr;
3✔
1070
  }
3✔
1071
  const gnutls_datum_t& getKey() const
1072
  {
14✔
1073
    return d_key;
14✔
1074
  }
14✔
1075

1076
private:
1077
  gnutls_datum_t d_key{nullptr, 0};
1078
};
1079

1080
class GnuTLSSession : public TLSSession
1081
{
1082
public:
1083
  GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
1084
  {
49✔
1085
    sess.data = nullptr;
49✔
1086
    sess.size = 0;
49✔
1087
  }
49✔
1088

1089
  ~GnuTLSSession() override
1090
  {
31✔
1091
    if (d_sess.data != nullptr && d_sess.size > 0) {
31!
1092
      safe_memory_release(d_sess.data, d_sess.size);
31✔
1093
    }
31✔
1094
    gnutls_free(d_sess.data);
31✔
1095
    d_sess.data = nullptr;
31✔
1096
  }
31✔
1097

1098
  const gnutls_datum_t& getNative()
1099
  {
31✔
1100
    return d_sess;
31✔
1101
  }
31✔
1102

1103
private:
1104
  gnutls_datum_t d_sess{nullptr, 0};
1105
};
1106

1107
class GnuTLSConnection: public TLSConnection
1108
{
1109
public:
1110
  /* server side connection */
1111
  GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
1112
  {
14✔
1113
    unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
14✔
1114
#ifdef GNUTLS_NO_SIGNAL
14✔
1115
    sslOptions |= GNUTLS_NO_SIGNAL;
14✔
1116
#endif
14✔
1117

1118
    d_socket = socket;
14✔
1119

1120
    gnutls_session_t conn;
14✔
1121
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
14!
1122
      throw std::runtime_error("Error creating TLS connection");
1123
    }
1124

1125
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
14✔
1126
    conn = nullptr;
14✔
1127

1128
    if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) {
14!
1129
      throw std::runtime_error("Error setting certificate and key to TLS connection");
1130
    }
1131

1132
    if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
14!
1133
      throw std::runtime_error("Error setting ciphers to TLS connection");
1134
    }
1135

1136
    if (enableTickets && d_ticketsKey) {
14!
1137
      const gnutls_datum_t& key = d_ticketsKey->getKey();
14✔
1138
      if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
14!
1139
        throw std::runtime_error("Error setting the tickets key to TLS connection");
1140
      }
1141
    }
14✔
1142

1143
    gnutls_transport_set_int(d_conn.get(), d_socket);
14✔
1144

1145
    /* timeouts are in milliseconds */
1146
    gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
14✔
1147
    gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
14✔
1148
  }
14✔
1149

1150
  /* client-side connection */
1151
  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
1152
  {
46✔
1153
    unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
46✔
1154
#ifdef GNUTLS_NO_SIGNAL
46✔
1155
    sslOptions |= GNUTLS_NO_SIGNAL;
46✔
1156
#endif
46✔
1157

1158
    d_socket = socket;
46✔
1159

1160
    gnutls_session_t conn;
46✔
1161
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
46!
1162
      throw std::runtime_error("Error creating TLS connection");
1163
    }
1164

1165
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
46✔
1166
    conn = nullptr;
46✔
1167

1168
    int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get());
46✔
1169
    if (rc != GNUTLS_E_SUCCESS) {
46!
1170
      throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
1171
    }
1172

1173
    rc = gnutls_priority_set(d_conn.get(), priorityCache);
46✔
1174
    if (rc != GNUTLS_E_SUCCESS) {
46!
1175
      throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
1176
    }
1177

1178
    gnutls_transport_set_int(d_conn.get(), d_socket);
46✔
1179

1180
    /* timeouts are in milliseconds */
1181
    gnutls_handshake_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
46✔
1182
    gnutls_record_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
46✔
1183

1184
#ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
46✔
1185
    if (validateCerts && !d_host.empty()) {
46!
1186
      gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
30✔
1187
      rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
30✔
1188
      if (rc != GNUTLS_E_SUCCESS) {
30!
1189
        throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
1190
      }
1191
    }
30✔
1192
#else
1193
    /* no hostname validation for you */
1194
#endif
1195

1196
    /* allow access to our data in the callbacks */
1197
    gnutls_session_set_ptr(d_conn.get(), this);
46✔
1198
    gnutls_handshake_set_hook_function(d_conn.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET, GNUTLS_HOOK_POST, newTicketFromServerCb);
46✔
1199
  }
46✔
1200

1201
  /* The callback prototype changed in 3.4.0. */
1202
#if GNUTLS_VERSION_NUMBER >= 0x030400
1203
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */, const gnutls_datum_t* /* msg */)
1204
#else
1205
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */)
1206
#endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */
1207
  {
49✔
1208
    if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) {
49!
1209
      return 0;
1210
    }
1211

1212
    GnuTLSConnection* conn = reinterpret_cast<GnuTLSConnection*>(gnutls_session_get_ptr(session));
49✔
1213
    if (conn == nullptr) {
49!
1214
      return 0;
1215
    }
1216

1217
    gnutls_datum_t sess{nullptr, 0};
49✔
1218
    auto ret = gnutls_session_get_data2(session, &sess);
49✔
1219
    /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */
1220
    if (ret != GNUTLS_E_SUCCESS || sess.size <= 4) {
49!
1221
      throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
1222
    }
1223
    conn->d_tlsSessions.push_back(std::make_unique<GnuTLSSession>(sess));
49✔
1224
    return 0;
49✔
1225
  }
49✔
1226

1227
  IOState tryConnect(bool fastOpen, [[maybe_unused]] const ComboAddress& remote) override
1228
  {
46✔
1229
    int ret = 0;
46✔
1230

1231
    if (fastOpen) {
46!
1232
#ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
1233
      gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
1234
#endif
1235
    }
1236

1237
    do {
46✔
1238
      ret = gnutls_handshake(d_conn.get());
46✔
1239
      if (ret == GNUTLS_E_SUCCESS) {
46!
1240
        d_handshakeDone = true;
1241
        return IOState::Done;
1242
      }
1243
      else if (ret == GNUTLS_E_AGAIN) {
46!
1244
        int direction = gnutls_record_get_direction(d_conn.get());
46✔
1245
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
46!
1246
      }
46✔
1247
      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
×
1248
        throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1249
      }
1250
    } while (ret == GNUTLS_E_INTERRUPTED);
46!
1251

1252
    throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1253
  }
46✔
1254

1255
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override
1256
  {
1257
    struct timeval start = {0, 0};
1258
    struct timeval remainingTime = timeout;
1259
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
1260
      gettimeofday(&start, nullptr);
1261
    }
1262

1263
    IOState state;
1264
    do {
1265
      state = tryConnect(fastOpen, remote);
1266
      if (state == IOState::Done) {
×
1267
        return;
1268
      }
1269
      else if (state == IOState::NeedRead) {
×
1270
        int result = waitForData(d_socket, remainingTime.tv_sec, remainingTime.tv_usec);
1271
        if (result <= 0) {
×
1272
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1273
        }
1274
      }
1275
      else if (state == IOState::NeedWrite) {
×
1276
        int result = waitForRWData(d_socket, false, remainingTime.tv_sec, remainingTime.tv_usec);
1277
        if (result <= 0) {
×
1278
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1279
        }
1280
      }
1281

1282
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
1283
        struct timeval now;
1284
        gettimeofday(&now, nullptr);
1285
        struct timeval elapsed = now - start;
1286
        if (now < start || remainingTime < elapsed) {
×
1287
          throw runtime_error("Timeout while establishing TLS connection");
1288
        }
1289
        start = now;
1290
        remainingTime = remainingTime - elapsed;
1291
      }
1292
    }
1293
    while (state != IOState::Done);
×
1294
  }
1295

1296
  void doHandshake() override
1297
  {
1298
    int ret = 0;
1299
    do {
1300
      ret = gnutls_handshake(d_conn.get());
1301
      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
×
1302
        if (d_client) {
×
1303
          throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1304
        }
1305
        else {
1306
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1307
        }
1308
      }
1309
    }
1310
    while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED);
×
1311

1312
    d_handshakeDone = true;
1313
  }
1314

1315
  IOState tryHandshake() override
1316
  {
120✔
1317
    int ret = 0;
120✔
1318

1319
    do {
120✔
1320
      ret = gnutls_handshake(d_conn.get());
120✔
1321
      if (ret == GNUTLS_E_SUCCESS) {
120✔
1322
        d_handshakeDone = true;
54✔
1323
        return IOState::Done;
54✔
1324
      }
54✔
1325
      else if (ret == GNUTLS_E_AGAIN) {
66✔
1326
        int direction = gnutls_record_get_direction(d_conn.get());
60✔
1327
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
60!
1328
      }
60✔
1329
      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
6!
1330
        if (d_client) {
6!
1331
          std::string error;
6✔
1332
#ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
6✔
1333
          if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
6!
1334
            gnutls_datum_t out;
6✔
1335
            if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) {
6!
1336
              error = " (" + std::string(reinterpret_cast<const char*>(out.data)) + ")";
6✔
1337
              gnutls_free(out.data);
6✔
1338
            }
6✔
1339
          }
6✔
1340
#endif /* HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS */
6✔
1341
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)) + error);
6✔
1342
        }
6✔
1343
        else {
1344
          throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1345
        }
1346
      }
6✔
1347
    } while (ret == GNUTLS_E_INTERRUPTED);
120!
1348

1349
    if (d_client) {
×
1350
      throw std::runtime_error("Error establishinging a new connection: " + std::string(gnutls_strerror(ret)));
1351
    }
1352
    else {
1353
      throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1354
    }
1355
  }
1356

1357
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
1358
  {
298✔
1359
    if (!d_handshakeDone) {
298✔
1360
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1361
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
1362
      auto state = tryHandshake();
92✔
1363
      if (state != IOState::Done) {
92✔
1364
        return state;
46✔
1365
      }
46✔
1366
    }
92✔
1367

1368
    do {
252✔
1369
      ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
252✔
1370
      if (res == 0) {
252!
1371
        throw std::runtime_error("Error writing to TLS connection");
1372
      }
1373
      else if (res > 0) {
252✔
1374
        pos += static_cast<size_t>(res);
246✔
1375
      }
246✔
1376
      else if (res < 0) {
6!
1377
        if (gnutls_error_is_fatal(res)) {
×
1378
          throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1379
        }
1380
        else if (res == GNUTLS_E_AGAIN) {
×
1381
          return IOState::NeedWrite;
1382
        }
1383
        vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
×
1384
      }
1385
    }
252✔
1386
    while (pos < toWrite);
252!
1387
    return IOState::Done;
252✔
1388
  }
252✔
1389

1390
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
1391
  {
501✔
1392
    if (!d_handshakeDone) {
501!
1393
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1394
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
1395
      auto state = tryHandshake();
1396
      if (state != IOState::Done) {
×
1397
        return state;
1398
      }
1399
    }
1400

1401
    do {
501✔
1402
      ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
501✔
1403
      if (res == 0) {
501✔
1404
        throw std::runtime_error("EOF while reading from TLS connection");
3✔
1405
      }
3✔
1406
      else if (res > 0) {
498✔
1407
        pos += static_cast<size_t>(res);
346✔
1408
        if (allowIncomplete) {
346✔
1409
          break;
70✔
1410
        }
70✔
1411
      }
346✔
1412
      else if (res < 0) {
152!
1413
        if (gnutls_error_is_fatal(res)) {
152✔
1414
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
33✔
1415
        }
33✔
1416
        else if (res == GNUTLS_E_AGAIN) {
119!
1417
          return IOState::NeedRead;
119✔
1418
        }
119✔
1419
        vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
×
1420
      }
1421
    }
501✔
1422
    while (pos < toRead);
501!
1423
    return IOState::Done;
346✔
1424
  }
501✔
1425

1426
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
1427
  {
1428
    size_t got = 0;
1429
    struct timeval start{0,0};
1430
    struct timeval  remainingTime = totalTimeout;
1431
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
1432
      gettimeofday(&start, nullptr);
1433
    }
1434

1435
    do {
1436
      ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
1437
      if (res == 0) {
×
1438
        throw std::runtime_error("EOF while reading from TLS connection");
1439
      }
1440
      else if (res > 0) {
×
1441
        got += static_cast<size_t>(res);
1442
        if (allowIncomplete) {
×
1443
          break;
1444
        }
1445
      }
1446
      else if (res < 0) {
×
1447
        if (gnutls_error_is_fatal(res)) {
×
1448
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
1449
        }
1450
        else if (res == GNUTLS_E_AGAIN) {
×
1451
          int result = waitForData(d_socket, readTimeout.tv_sec, readTimeout.tv_usec);
1452
          if (result <= 0) {
×
1453
            throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
1454
          }
1455
        }
1456
        else {
1457
          vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
×
1458
        }
1459
      }
1460

1461
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
1462
        struct timeval now;
1463
        gettimeofday(&now, nullptr);
1464
        struct timeval elapsed = now - start;
1465
        if (now < start || remainingTime < elapsed) {
×
1466
          throw runtime_error("Timeout while reading data");
1467
        }
1468
        start = now;
1469
        remainingTime = remainingTime - elapsed;
1470
      }
1471
    }
1472
    while (got < bufferSize);
×
1473

1474
    return got;
1475
  }
1476

1477
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
1478
  {
1479
    size_t got = 0;
1480

1481
    do {
1482
      ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
1483
      if (res == 0) {
×
1484
        throw std::runtime_error("Error writing to TLS connection");
1485
      }
1486
      else if (res > 0) {
×
1487
        got += static_cast<size_t>(res);
1488
      }
1489
      else if (res < 0) {
×
1490
        if (gnutls_error_is_fatal(res)) {
×
1491
          throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1492
        }
1493
        else if (res == GNUTLS_E_AGAIN) {
×
1494
          int result = waitForRWData(d_socket, false, writeTimeout.tv_sec, writeTimeout.tv_usec);
1495
          if (result <= 0) {
×
1496
            throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
1497
          }
1498
        }
1499
        else {
1500
          vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
×
1501
        }
1502
      }
1503
    }
1504
    while (got < bufferSize);
×
1505

1506
    return got;
1507
  }
1508

1509
  bool isUsable() const override
1510
  {
2✔
1511
    if (!d_conn) {
2!
1512
      return false;
1513
    }
1514

1515
    /* as far as I can tell we can't peek so we cannot do better */
1516
    return isTCPSocketUsable(d_socket);
2✔
1517
  }
2✔
1518

1519
  std::string getServerNameIndication() const override
1520
  {
112✔
1521
    if (d_conn) {
112!
1522
      unsigned int type;
112✔
1523
      size_t name_len = 256;
112✔
1524
      std::string sni;
112✔
1525
      sni.resize(name_len);
112✔
1526

1527
      int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
112✔
1528
      if (res == GNUTLS_E_SUCCESS) {
112!
1529
        sni.resize(name_len);
112✔
1530
        return sni;
112✔
1531
      }
112✔
1532
    }
112✔
1533
    return std::string();
1534
  }
112✔
1535

1536
  std::vector<uint8_t> getNextProtocol() const override
1537
  {
1538
    std::vector<uint8_t> result;
1539
    if (!d_conn) {
×
1540
      return result;
1541
    }
1542
    gnutls_datum_t next;
1543
    if (gnutls_alpn_get_selected_protocol(d_conn.get(), &next) != GNUTLS_E_SUCCESS) {
×
1544
      return result;
1545
    }
1546
    result.insert(result.end(), next.data, next.data + next.size);
1547
    return result;
1548
  }
1549

1550
  LibsslTLSVersion getTLSVersion() const override
1551
  {
112✔
1552
    auto proto = gnutls_protocol_get_version(d_conn.get());
112✔
1553
    switch (proto) {
112✔
1554
    case GNUTLS_TLS1_0:
112!
1555
      return LibsslTLSVersion::TLS10;
1556
    case GNUTLS_TLS1_1:
112!
1557
      return LibsslTLSVersion::TLS11;
1558
    case GNUTLS_TLS1_2:
3✔
1559
      return LibsslTLSVersion::TLS12;
3✔
1560
#if GNUTLS_VERSION_NUMBER >= 0x030603
1561
    case GNUTLS_TLS1_3:
109✔
1562
      return LibsslTLSVersion::TLS13;
109✔
1563
#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
1564
    default:
112!
1565
      return LibsslTLSVersion::Unknown;
1566
    }
112✔
1567
  }
112✔
1568

1569
  bool hasSessionBeenResumed() const override
1570
  {
46✔
1571
    if (d_conn) {
46!
1572
      return gnutls_session_is_resumed(d_conn.get()) != 0;
46✔
1573
    }
46✔
1574
    return false;
1575
  }
46✔
1576

1577
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
1578
  {
36✔
1579
    return std::move(d_tlsSessions);
36✔
1580
  }
36✔
1581

1582
  void setSession(std::unique_ptr<TLSSession>& session) override
1583
  {
31✔
1584
    auto sess = dynamic_cast<GnuTLSSession*>(session.get());
31✔
1585
    if (!sess) {
31!
1586
      throw std::runtime_error("Unable to convert GnuTLS session");
1587
    }
1588

1589
    auto native = sess->getNative();
31✔
1590
    auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size);
31✔
1591
    if (ret != GNUTLS_E_SUCCESS) {
31!
1592
      throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
1593
    }
1594
    session.reset();
31✔
1595
  }
31✔
1596

1597
  void close() override
1598
  {
54✔
1599
    if (d_conn) {
54!
1600
      gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR);
54✔
1601
    }
54✔
1602
  }
54✔
1603

1604
  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos)
1605
  {
60✔
1606
    std::vector<gnutls_datum_t> values;
60✔
1607
    values.reserve(protos.size());
60✔
1608
    for (const auto& proto : protos) {
60✔
1609
      gnutls_datum_t value;
60✔
1610
      value.data = const_cast<uint8_t*>(proto.data());
60✔
1611
      value.size = proto.size();
60✔
1612
      values.push_back(value);
60✔
1613
    }
60✔
1614
    unsigned int flags = 0;
60✔
1615
#if GNUTLS_VERSION_NUMBER >= 0x030500
60✔
1616
    flags |= GNUTLS_ALPN_MANDATORY;
60✔
1617
#elif defined(GNUTLS_ALPN_MAND)
1618
    flags |= GNUTLS_ALPN_MAND;
1619
#endif
1620
    return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), flags);
60✔
1621
  }
60✔
1622

1623
  std::vector<int> getAsyncFDs() override
1624
  {
14✔
1625
    return {};
14✔
1626
  }
14✔
1627

1628
private:
1629
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1630
  std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
1631
  std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
1632
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
1633
  std::string d_host;
1634
  const bool d_client{false};
1635
  bool d_handshakeDone{false};
1636
};
1637

1638
class GnuTLSIOCtx: public TLSCtx
1639
{
1640
public:
1641
  /* server side context */
1642
  GnuTLSIOCtx(TLSFrontend& frontend): d_protos(getALPNVector(frontend.d_alpn, false)), d_enableTickets(frontend.d_tlsConfig.d_enableTickets)
1643
  {
6✔
1644
    int rc = 0;
6✔
1645
    d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;
6✔
1646

1647
    gnutls_certificate_credentials_t creds;
6✔
1648
    rc = gnutls_certificate_allocate_credentials(&creds);
6✔
1649
    if (rc != GNUTLS_E_SUCCESS) {
6!
1650
      throw std::runtime_error("Error allocating credentials for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1651
    }
1652

1653
    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
6✔
1654
    creds = nullptr;
6✔
1655

1656
    for (const auto& pair : frontend.d_tlsConfig.d_certKeyPairs) {
6✔
1657
      rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.d_cert.c_str(), pair.d_key->c_str(), GNUTLS_X509_FMT_PEM);
6✔
1658
      if (rc != GNUTLS_E_SUCCESS) {
6!
1659
        throw std::runtime_error("Error loading certificate ('" + pair.d_cert + "') and key ('" + pair.d_key.value() + "') for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1660
      }
1661
    }
6✔
1662

1663
#ifndef DISABLE_OCSP_STAPLING
6✔
1664
    size_t count = 0;
6✔
1665
    for (const auto& file : frontend.d_tlsConfig.d_ocspFiles) {
6✔
1666
      rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
3✔
1667
      if (rc != GNUTLS_E_SUCCESS) {
3✔
1668
        warnlog("Error loading OCSP response from file '%s' for certificate ('%s') and key ('%s') for TLS context on %s: %s", file, frontend.d_tlsConfig.d_certKeyPairs.at(count).d_cert, frontend.d_tlsConfig.d_certKeyPairs.at(count).d_key.value(), frontend.d_addr.toStringWithPort(), gnutls_strerror(rc));
1✔
1669
      }
1✔
1670
      ++count;
3✔
1671
    }
3✔
1672
#endif /* DISABLE_OCSP_STAPLING */
6✔
1673

1674
#if GNUTLS_VERSION_NUMBER >= 0x030600
6✔
1675
    rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
6✔
1676
    if (rc != GNUTLS_E_SUCCESS) {
6!
1677
      throw std::runtime_error("Error setting DH params for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1678
    }
1679
#endif
6✔
1680

1681
    rc = gnutls_priority_init(&d_priorityCache, frontend.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : frontend.d_tlsConfig.d_ciphers.c_str(), nullptr);
6!
1682
    if (rc != GNUTLS_E_SUCCESS) {
6!
1683
      throw std::runtime_error("Error setting up TLS cipher preferences to '" + frontend.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + frontend.d_addr.toStringWithPort());
1684
    }
1685

1686
    try {
6✔
1687
      if (frontend.d_tlsConfig.d_ticketKeyFile.empty()) {
6!
1688
        handleTicketsKeyRotation(time(nullptr));
6✔
1689
      }
6✔
1690
      else {
1691
        GnuTLSIOCtx::loadTicketsKeys(frontend.d_tlsConfig.d_ticketKeyFile);
1692
      }
1693
    }
6✔
1694
    catch(const std::runtime_error& e) {
6✔
1695
      throw std::runtime_error("Error generating tickets key for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + e.what());
1696
    }
1697
  }
6✔
1698

1699
  /* client side context */
1700
  GnuTLSIOCtx(const TLSContextParameters& params): d_protos(getALPNVector(params.d_alpn, true)), d_contextParameters(std::make_unique<TLSContextParameters>(params)), d_validateCerts(params.d_validateCertificates)
1701
  {
18✔
1702
    int rc = 0;
18✔
1703

1704
    gnutls_certificate_credentials_t creds;
18✔
1705
    rc = gnutls_certificate_allocate_credentials(&creds);
18✔
1706
    if (rc != GNUTLS_E_SUCCESS) {
18!
1707
      throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1708
    }
1709

1710
    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
18✔
1711
    creds = nullptr;
18✔
1712

1713
    if (params.d_validateCertificates) {
18✔
1714
      if (params.d_caStore.empty()) {
14!
1715
#if GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703
1716
        /* see https://gitlab.com/gnutls/gnutls/-/issues/1277 */
1717
        std::cerr<<"Warning: GnuTLS 3.7.0 - 3.7.2 have a memory leak when validating server certificates in some configurations (PKCS11 support enabled, and a default PKCS11 trust store), please consider upgrading GnuTLS, using the OpenSSL provider for outgoing connections, or explicitly setting a CA store"<<std::endl;
1718
#endif /* GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 */
1719
        rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
1720
        if (rc < 0) {
×
1721
          throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1722
        }
1723
      }
1724
      else {
14✔
1725
        rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
14✔
1726
        if (rc < 0) {
14!
1727
          throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1728
        }
1729
      }
14✔
1730
    }
14✔
1731

1732
    rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
18!
1733
    if (rc != GNUTLS_E_SUCCESS) {
18!
1734
      throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
1735
    }
1736
  }
18✔
1737

1738
  ~GnuTLSIOCtx() override
1739
  {
2✔
1740
    d_creds.reset();
2✔
1741

1742
    if (d_priorityCache) {
2!
1743
      gnutls_priority_deinit(d_priorityCache);
2✔
1744
    }
2✔
1745
  }
2✔
1746

1747
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
1748
  {
14✔
1749
    handleTicketsKeyRotation(now);
14✔
1750

1751
    std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
14✔
1752
    {
14✔
1753
      ticketsKey = *(d_ticketsKey.read_lock());
14✔
1754
    }
14✔
1755

1756
    auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets);
14✔
1757
    if (!d_protos.empty()) {
14!
1758
      connection->setALPNProtos(d_protos);
14✔
1759
    }
14✔
1760
    return connection;
14✔
1761
  }
14✔
1762

1763
  static std::shared_ptr<gnutls_certificate_credentials_st> getPerThreadCredentials(bool validate, const std::string& caStore)
1764
  {
46✔
1765
    static thread_local std::map<std::pair<bool, std::string>, std::shared_ptr<gnutls_certificate_credentials_st>> t_credentials;
46✔
1766
    auto& entry = t_credentials[{validate, caStore}];
46✔
1767
    if (!entry) {
46✔
1768
      gnutls_certificate_credentials_t creds;
18✔
1769
      int rc = gnutls_certificate_allocate_credentials(&creds);
18✔
1770
      if (rc != GNUTLS_E_SUCCESS) {
18!
1771
        throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1772
      }
1773

1774
      entry = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
18✔
1775
      creds = nullptr;
18✔
1776

1777
      if (validate) {
18✔
1778
        if (caStore.empty()) {
13!
1779
          rc = gnutls_certificate_set_x509_system_trust(entry.get());
1780
          if (rc < 0) {
×
1781
            throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1782
          }
1783
        }
1784
        else {
13✔
1785
          rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM);
13✔
1786
          if (rc < 0) {
13!
1787
            throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1788
          }
1789
        }
13✔
1790
      }
13✔
1791
    }
18✔
1792
    return entry;
46✔
1793
  }
46✔
1794

1795
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool, int socket, const struct timeval& timeout) override
1796
  {
46✔
1797
    auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore);
46✔
1798
    auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, creds, d_priorityCache, d_validateCerts);
46✔
1799
    if (!d_protos.empty()) {
46!
1800
      connection->setALPNProtos(d_protos);
46✔
1801
    }
46✔
1802
    return connection;
46✔
1803
  }
46✔
1804

1805
  void addTicketsKey(time_t now, std::shared_ptr<GnuTLSTicketsKey>&& newKey)
1806
  {
7✔
1807
    if (!d_enableTickets) {
7!
1808
      return;
1809
    }
1810

1811
    {
7✔
1812
      *(d_ticketsKey.write_lock()) = std::move(newKey);
7✔
1813
    }
7✔
1814

1815
    if (d_ticketsKeyRotationDelay > 0) {
7!
1816
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
7✔
1817
    }
7✔
1818

1819
    if (TLSCtx::hasTicketsKeyAddedHook()) {
7✔
1820
      auto ticketsKey = *(d_ticketsKey.read_lock());
2✔
1821
      auto content = ticketsKey->content();
2✔
1822
      TLSCtx::getTicketsKeyAddedHook()(content);
2✔
1823
      safe_memory_release(content.data(), content.size());
2✔
1824
    }
2✔
1825
  }
7✔
1826
  void rotateTicketsKey(time_t now) override
1827
  {
6✔
1828
    if (!d_enableTickets) {
6!
1829
      return;
1830
    }
1831

1832
    auto newKey = std::make_shared<GnuTLSTicketsKey>();
6✔
1833
    addTicketsKey(now, std::move(newKey));
6✔
1834
  }
6✔
1835
  void loadTicketsKey(const std::string& key) final
1836
  {
1✔
1837
    if (!d_enableTickets) {
1!
1838
      return;
1839
    }
1840

1841
    auto newKey = std::make_shared<GnuTLSTicketsKey>(key);
1✔
1842
    addTicketsKey(time(nullptr), std::move(newKey));
1✔
1843
  }
1✔
1844

1845
  void loadTicketsKeys(const std::string& keyFile) final
1846
  {
1847
    if (!d_enableTickets) {
×
1848
      return;
1849
    }
1850

1851
    std::ifstream file(keyFile);
1852
    auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1853
    addTicketsKey(time(nullptr), std::move(newKey));
1854
    file.close();
1855
  }
1856

1857
  size_t getTicketsKeysCount() override
1858
  {
1859
    return *(d_ticketsKey.read_lock()) != nullptr ? 1 : 0;
×
1860
  }
1861

1862
  std::string getName() const override
1863
  {
3✔
1864
    return "gnutls";
3✔
1865
  }
3✔
1866

1867
private:
1868
  /* client context parameters */
1869
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1870
  const std::vector<std::vector<uint8_t>> d_protos;
1871
  std::unique_ptr<TLSContextParameters> d_contextParameters{nullptr};
1872
  gnutls_priority_t d_priorityCache{nullptr};
1873
  SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
1874
  bool d_enableTickets{true};
1875
  bool d_validateCerts{true};
1876
};
1877

1878
#endif /* HAVE_GNUTLS */
1879

1880
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1881

1882
bool TLSFrontend::setupTLS()
1883
{
57✔
1884
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
57✔
1885
  std::shared_ptr<TLSCtx> newCtx{nullptr};
57✔
1886
  /* get the "best" available provider */
1887
#if defined(HAVE_GNUTLS)
57✔
1888
  if (d_provider == "gnutls") {
57✔
1889
    newCtx = std::make_shared<GnuTLSIOCtx>(*this);
6✔
1890
  }
6✔
1891
#endif /* HAVE_GNUTLS */
57✔
1892
#if defined(HAVE_LIBSSL)
57✔
1893
  if (d_provider == "openssl") {
57✔
1894
    newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
17✔
1895
  }
17✔
1896
#endif /* HAVE_LIBSSL */
57✔
1897

1898
  if (!newCtx) {
57✔
1899
#if defined(HAVE_LIBSSL)
34✔
1900
    newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
34✔
1901
#elif defined(HAVE_GNUTLS)
1902
    newCtx = std::make_shared<GnuTLSIOCtx>(*this);
1903
#else
1904
#error "TLS support needed but neither libssl nor GnuTLS were selected"
1905
#endif
1906
  }
34✔
1907

1908
  std::atomic_store_explicit(&d_ctx, std::move(newCtx), std::memory_order_release);
57✔
1909
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
57✔
1910
  return true;
57✔
1911
}
57✔
1912

1913
std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
1914
{
48✔
1915
#ifdef HAVE_DNS_OVER_TLS
48✔
1916
  /* get the "best" available provider */
1917
  if (!params.d_provider.empty()) {
48✔
1918
#if defined(HAVE_GNUTLS)
36✔
1919
    if (params.d_provider == "gnutls") {
36✔
1920
      return std::make_shared<GnuTLSIOCtx>(params);
18✔
1921
    }
18✔
1922
#endif /* HAVE_GNUTLS */
18✔
1923
#if defined(HAVE_LIBSSL)
23✔
1924
    if (params.d_provider == "openssl") {
23!
1925
      return OpenSSLTLSIOCtx::createClientSideContext(params);
23✔
1926
    }
23✔
1927
#endif /* HAVE_LIBSSL */
23✔
1928
  }
23✔
1929

1930
#if defined(HAVE_LIBSSL)
7✔
1931
  return OpenSSLTLSIOCtx::createClientSideContext(params);
7✔
1932
#elif defined(HAVE_GNUTLS)
1933
  return std::make_shared<GnuTLSIOCtx>(params);
1934
#else
1935
#error "DNS over TLS support needed but neither libssl nor GnuTLS were selected"
1936
#endif
1937

1938
#endif /* HAVE_DNS_OVER_TLS */
×
1939
  return nullptr;
×
1940
}
48✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc