• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 15468206343

05 Jun 2025 01:23PM UTC coverage: 63.692% (-0.008%) from 63.7%
15468206343

push

github

web-flow
Merge pull request #15607 from miodvallat/too_much_sugar

Try harder matching command names in pdnsutil

42373 of 101390 branches covered (41.79%)

Branch coverage included in aggregate %.

130649 of 170266 relevant lines covered (76.73%)

4341228.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.64
/pdns/tcpiohandler.cc
1

2
#include "config.h"
3
#include "dolog.hh"
4
#include "iputils.hh"
5
#include "lock.hh"
6
#include "tcpiohandler.hh"
7

8
// Definition of the flag; in this translation unit it is always false.
const bool TCPIOHandler::s_disableConnectForUnitTests{false};
9

10
#ifdef HAVE_LIBSODIUM
11
#include <sodium.h>
12
#endif /* HAVE_LIBSODIUM */
13

14
// Starts out with no hook installed.
TLSCtx::tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook = nullptr;
15

16
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
17
static std::vector<std::vector<uint8_t>> getALPNVector(TLSFrontend::ALPN alpn, bool client)
18
{
133✔
19
  if (alpn == TLSFrontend::ALPN::DoT) {
133✔
20
    /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */
21
    return std::vector<std::vector<uint8_t>>{{'d', 'o', 't'}};
57✔
22
  }
57✔
23
  if (alpn == TLSFrontend::ALPN::DoH) {
76✔
24
    if (client) {
64✔
25
      /* we want to set the ALPN to h2, if only to mitigate the ALPACA attack */
26
      return std::vector<std::vector<uint8_t>>{{'h', '2'}};
28✔
27
    }
28✔
28
    /* For server contexts, we want to set the ALPN for DoH (note that h2o sets it own ALPN values):
29
       - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response
30
       - HTTP/2
31
    */
32
    return std::vector<std::vector<uint8_t>>{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}};
36✔
33
  }
64✔
34
  return {};
12✔
35
}
76✔
36

37
#ifdef HAVE_LIBSSL
38

39
namespace {
/* Whether detailed OpenSSL error output should be emitted.
   dnsdist follows its runtime 'verbose' setting, the recursor never does,
   and every other build always does. */
[[nodiscard]] bool shouldDoVerboseLogging()
{
#if defined(DNSDIST)
  const auto& runtime = dnsdist::configuration::getCurrentRuntimeConfiguration();
  return runtime.d_verbose;
#elif defined(RECURSOR)
  return false;
#else
  return true;
#endif
}
} // namespace
51

52
#include <openssl/conf.h>
53
#include <openssl/err.h>
54
#include <openssl/rand.h>
55
#include <openssl/ssl.h>
56
#include <openssl/x509v3.h>
57

58
#include "libssl.hh"
59

60
/* Forward declaration: OpenSSLFrontendContext below installs this callback before its definition. */
static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* /* arg */);
61

62
class OpenSSLFrontendContext
63
{
64
public:
65
  OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
67✔
66
  {
67✔
67
    registerOpenSSLUser();
67✔
68

69
    auto [ctx, warnings] = libssl_init_server_context(tlsConfig);
67✔
70
    for (const auto& warning : warnings) {
67✔
71
      warnlog("%s", warning);
2✔
72
    }
2✔
73
    // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): it cannot be initialized before calling libssl_init_server_context()
74
    d_ocspResponses = std::move(ctx.d_ocspResponses);
67✔
75
    // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): it cannot be initialized before calling libssl_init_server_context()
76
    d_tlsCtx = std::move(ctx.d_defaultContext);
67✔
77
    // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): it cannot be initialized before calling libssl_init_server_context()
78
    d_sniMap = std::move(ctx.d_sniMap);
67✔
79
    for (auto& entry : d_sniMap) {
136✔
80
      SSL_CTX_set_tlsext_servername_callback(entry.second.get(), &sni_server_name_callback);
136✔
81
    }
136✔
82

83
    if (!d_tlsCtx) {
67!
84
      ERR_print_errors_fp(stderr);
×
85
      throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
×
86
    }
×
87
  }
67✔
88

89
  void cleanup()
90
  {
×
91
    d_tlsCtx.reset();
×
92

×
93
    unregisterOpenSSLUser();
×
94
  }
×
95

96
  OpenSSLTLSTicketKeysRing d_ticketKeys;
97
  std::map<int, std::string> d_ocspResponses;
98
  pdns::libssl::ServerContext::SNIToContextMap d_sniMap;
99
  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr};
100
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
101
};
102

103

104
/* SNI callback installed on the server-side SSL_CTXs: when the client sent a server name
   that has a dedicated context in the frontend's SNI map, move the connection to that
   context; in every other case the default certificate is kept. */
static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* /* arg */)
{
  const char* requestedName = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
  if (requestedName == nullptr) {
    return SSL_TLSEXT_ERR_NOACK;
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
  auto* frontendCtx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(ssl));
  if (frontendCtx == nullptr) {
    return SSL_TLSEXT_ERR_OK;
  }

  const auto& sniMap = frontendCtx->d_sniMap;
  const auto match = sniMap.find(std::string_view(requestedName));
  if (match != sniMap.end()) {
    /* the result is ignored on purpose: if switching fails there is nothing we can do,
       and OpenSSL should fall back to the existing, default certificate */
    SSL_set_SSL_CTX(ssl, match->second.get());
  }

  return SSL_TLSEXT_ERR_OK;
}
536✔
130

131
class OpenSSLSession : public TLSSession
132
{
133
public:
134
  OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>&& sess): d_sess(std::move(sess))
131✔
135
  {
141✔
136
  }
141✔
137

138
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
139
  {
81✔
140
    return std::move(d_sess);
81✔
141
  }
81✔
142

143
private:
144
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
145
};
146

147
class OpenSSLTLSIOCtx;
148

149
class OpenSSLTLSConnection: public TLSConnection
150
{
151
public:
152
  /* server side connection */
153
  OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout)
545✔
154
  {
545✔
155
    d_socket = socket;
545✔
156

157
    if (!d_conn) {
545!
158
      vinfolog("Error creating TLS object");
×
159
      if (shouldDoVerboseLogging()) {
×
160
        ERR_print_errors_fp(stderr);
×
161
      }
×
162
      throw std::runtime_error("Error creating TLS object");
×
163
    }
×
164

165
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
545!
166
      throw std::runtime_error("Error assigning socket");
×
167
    }
×
168

169
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
545✔
170
  }
545✔
171

172
  /* client-side connection */
173
  OpenSSLTLSConnection(std::string hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true)
134✔
174
  {
139✔
175
    d_socket = socket;
139✔
176

177
    if (!d_conn) {
139!
178
      vinfolog("Error creating TLS object");
×
179
      if (shouldDoVerboseLogging()) {
×
180
        ERR_print_errors_fp(stderr);
×
181
      }
×
182
      throw std::runtime_error("Error creating TLS object");
×
183
    }
×
184

185
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
139!
186
      throw std::runtime_error("Error assigning socket");
×
187
    }
×
188

189
    /* set outgoing Server Name Indication */
190
    if (!d_hostname.empty() && SSL_set_tlsext_host_name(d_conn.get(), d_hostname.c_str()) != 1) {
139!
191
      throw std::runtime_error("Error setting TLS SNI to " + d_hostname);
×
192
    }
×
193

194
    if (hostIsAddr) {
139✔
195
#if (OPENSSL_VERSION_NUMBER >= 0x10002000L)
2✔
196
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
2✔
197
      /* Enable automatic IP checks */
198
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
2✔
199
      if (X509_VERIFY_PARAM_set1_ip_asc(param, d_hostname.c_str()) != 1) {
2!
200
        throw std::runtime_error("Error setting TLS IP for certificate validation");
×
201
      }
×
202
#else
203
      /* no validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
204
#endif
205
    }
2✔
206
    else {
137✔
207
#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl
137✔
208
      SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
137✔
209
      if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
137!
210
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
×
211
      }
×
212
#elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
213
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
214
      /* Enable automatic hostname checks */
215
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
216
      if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
217
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
218
      }
219
#else
220
      /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
221
#endif
222
    }
137✔
223

224
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
139✔
225
  }
139✔
226

227
  std::vector<int> getAsyncFDs() override
228
  {
348✔
229
    std::vector<int> results;
348✔
230
#ifdef SSL_MODE_ASYNC
348✔
231
    if (SSL_waiting_for_async(d_conn.get()) != 1) {
348!
232
      return results;
348✔
233
    }
348✔
234

235
    OSSL_ASYNC_FD fds[32];
348✔
236
    size_t numfds = sizeof(fds)/sizeof(*fds);
×
237
    SSL_get_all_async_fds(d_conn.get(), nullptr, &numfds);
×
238
    if (numfds == 0) {
×
239
      return results;
×
240
    }
×
241

242
    SSL_get_all_async_fds(d_conn.get(), fds, &numfds);
×
243
    results.reserve(numfds);
×
244
    for (size_t idx = 0; idx < numfds; idx++) {
×
245
      results.push_back(fds[idx]);
×
246
    }
×
247
#endif
×
248
    return results;
×
249
  }
×
250

251
  IOState convertIORequestToIOState(int res) const
252
  {
2,387✔
253
    int error = SSL_get_error(d_conn.get(), res);
2,387✔
254
    if (error == SSL_ERROR_WANT_READ) {
2,387✔
255
      return IOState::NeedRead;
1,806✔
256
    }
1,806✔
257
    else if (error == SSL_ERROR_WANT_WRITE) {
581✔
258
      return IOState::NeedWrite;
12✔
259
    }
12✔
260
    else if (error == SSL_ERROR_SYSCALL) {
569✔
261
      if (errno == 0) {
24!
262
        throw std::runtime_error("TLS connection closed by remote end");
×
263
      }
×
264
      else {
24✔
265
        throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
24✔
266
      }
24✔
267
    }
24✔
268
    else if (error == SSL_ERROR_ZERO_RETURN) {
545✔
269
      throw std::runtime_error("TLS connection closed by remote end");
538✔
270
    }
538✔
271
#ifdef SSL_MODE_ASYNC
7✔
272
    else if (error == SSL_ERROR_WANT_ASYNC) {
7!
273
      return IOState::Async;
×
274
    }
×
275
#endif
7✔
276
    else {
7✔
277
      if (shouldDoVerboseLogging()) {
7✔
278
        throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error) + ") " + libssl_get_error_string());
1✔
279
      } else {
6✔
280
        throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
6✔
281
      }
6✔
282
    }
7✔
283
  }
2,387✔
284

285
  void handleIORequest(int res, const struct timeval& timeout)
286
  {
5✔
287
    auto state = convertIORequestToIOState(res);
5✔
288
    if (state == IOState::NeedRead) {
5✔
289
      res = waitForData(d_socket, timeout.tv_sec, timeout.tv_usec);
4✔
290
      if (res == 0) {
4!
291
        throw std::runtime_error("Timeout while reading from TLS connection");
×
292
      }
×
293
      else if (res < 0) {
4!
294
        throw std::runtime_error("Error waiting to read from TLS connection");
×
295
      }
×
296
    }
4✔
297
    else if (state == IOState::NeedWrite) {
1!
298
      res = waitForRWData(d_socket, false, timeout.tv_sec, timeout.tv_usec);
×
299
      if (res == 0) {
×
300
        throw std::runtime_error("Timeout while writing to TLS connection");
×
301
      }
×
302
      else if (res < 0) {
×
303
        throw std::runtime_error("Error waiting to write to TLS connection");
×
304
      }
×
305
    }
×
306
  }
5✔
307

308
  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
309
  {
132✔
310
    /* sorry */
311
    (void) fastOpen;
132✔
312
    (void) remote;
132✔
313

314
    int res = SSL_connect(d_conn.get());
132✔
315
    if (res == 1) {
132!
316
      return IOState::Done;
×
317
    }
×
318
    else if (res < 0) {
132!
319
      return convertIORequestToIOState(res);
132✔
320
    }
132✔
321

322
    throw std::runtime_error("Error establishing a TLS connection");
×
323
  }
132✔
324

325
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval &timeout) override
326
  {
4✔
327
    /* sorry */
328
    (void) fastOpen;
4✔
329
    (void) remote;
4✔
330

331
    struct timeval start{0,0};
4✔
332
    struct timeval remainingTime = timeout;
4✔
333
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
4!
334
      gettimeofday(&start, nullptr);
4✔
335
    }
4✔
336

337
    int res = 0;
4✔
338
    do {
8✔
339
      res = SSL_connect(d_conn.get());
8✔
340
      if (res < 0) {
8✔
341
        handleIORequest(res, remainingTime);
5✔
342
      }
5✔
343

344
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
8!
345
        struct timeval now;
7✔
346
        gettimeofday(&now, nullptr);
7✔
347
        struct timeval elapsed = now - start;
7✔
348
        if (now < start || remainingTime < elapsed) {
7!
349
          throw runtime_error("Timeout while establishing TLS connection");
×
350
        }
×
351
        start = now;
7✔
352
        remainingTime = remainingTime - elapsed;
7✔
353
      }
7✔
354
    }
8✔
355
    while (res != 1);
8✔
356
  }
4✔
357

358
  IOState tryHandshake() override
359
  {
1,088✔
360
    if (isClient()) {
1,088!
361
      /* In client mode, the handshake is initiated by the call to SSL_connect()
362
         done from connect()/tryConnect().
363
         In blocking mode it does not return before the handshake has been finished,
364
         and in non-blocking mode calling SSL_connect() once is enough for SSL_write()
365
         and SSL_read() to transparently continue to negotiate the connection after that
366
         (equivalent to doing SSL_set_connect_state() plus trying to write).
367
      */
368
      return IOState::Done;
×
369
    }
×
370

371
    /* As explained above in the client-mode block, we only need to call SSL_accept() once
372
       for SSL_write() and SSL_read() to transparently continue to negotiate the connection after that.
373
       It is equivalent to calling SSL_set_accept_state() plus trying to read.
374
    */
375
    int res = SSL_accept(d_conn.get());
1,088✔
376
    if (res == 1) {
1,088✔
377
      return IOState::Done;
543✔
378
    }
543✔
379
    else if (res < 0) {
545!
380
      return convertIORequestToIOState(res);
545✔
381
    }
545✔
382

383
    throw std::runtime_error("Error accepting TLS connection");
×
384
  }
1,088✔
385

386
  void doHandshake() override
387
  {
×
388
    if (isClient()) {
×
389
      /* we are a client, nothing to do, see the non-blocking version */
390
      return;
×
391
    }
×
392

393
    int res = 0;
×
394
    do {
×
395
      res = SSL_accept(d_conn.get());
×
396
      if (res < 0) {
×
397
        handleIORequest(res, d_timeout);
×
398
      }
×
399
    }
×
400
    while (res < 0);
×
401

402
    if (res != 1) {
×
403
      throw std::runtime_error("Error accepting TLS connection");
×
404
    }
×
405
  }
×
406

407
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
408
  {
1,600✔
409
    if (isClient() && !d_connected) {
1,600✔
410
      if (d_ktls) {
229!
411
        /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
412
        SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
×
413
      }
×
414
    }
229✔
415

416
    do {
1,600✔
417
      int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
1,600✔
418
      if (res <= 0) {
1,600✔
419
        return convertIORequestToIOState(res);
129✔
420
      }
129✔
421
      else {
1,471✔
422
        pos += static_cast<size_t>(res);
1,471✔
423
      }
1,471✔
424
    }
1,600✔
425
    while (pos < toWrite);
1,600!
426

427
    if (!d_connected) {
1,471✔
428
      d_connected = true;
612✔
429
    }
612✔
430

431
    return IOState::Done;
1,471✔
432
  }
1,600✔
433

434
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
435
  {
3,779✔
436
    do {
3,781✔
437
      int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
3,781✔
438
      if (res <= 0) {
3,781✔
439
        return convertIORequestToIOState(res);
1,573✔
440
      }
1,573✔
441
      else {
2,208✔
442
        pos += static_cast<size_t>(res);
2,208✔
443
        if (allowIncomplete) {
2,208✔
444
          break;
766✔
445
        }
766✔
446
      }
2,208✔
447
    }
3,781✔
448
    while (pos < toRead);
3,779✔
449
    return IOState::Done;
2,206✔
450
  }
3,779✔
451

452
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
453
  {
×
454
    size_t got = 0;
×
455
    struct timeval start = {0, 0};
×
456
    struct timeval remainingTime = totalTimeout;
×
457
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
458
      gettimeofday(&start, nullptr);
×
459
    }
×
460

461
    do {
×
462
      int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
463
      if (res <= 0) {
×
464
        handleIORequest(res, readTimeout);
×
465
      }
×
466
      else {
×
467
        got += static_cast<size_t>(res);
×
468
        if (allowIncomplete) {
×
469
          break;
×
470
        }
×
471
      }
×
472

473
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
474
        struct timeval now;
×
475
        gettimeofday(&now, nullptr);
×
476
        struct timeval elapsed = now - start;
×
477
        if (now < start || remainingTime < elapsed) {
×
478
          throw runtime_error("Timeout while reading data");
×
479
        }
×
480
        start = now;
×
481
        remainingTime = remainingTime - elapsed;
×
482
      }
×
483
    }
×
484
    while (got < bufferSize);
×
485

486
    return got;
×
487
  }
×
488

489
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
490
  {
×
491
    size_t got = 0;
×
492
    do {
×
493
      int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
494
      if (res <= 0) {
×
495
        handleIORequest(res, writeTimeout);
×
496
      }
×
497
      else {
×
498
        got += static_cast<size_t>(res);
×
499
      }
×
500
    }
×
501
    while (got < bufferSize);
×
502

503
    return got;
×
504
  }
×
505

506
  bool isUsable() const override
507
  {
3✔
508
    if (!d_conn) {
3!
509
      return false;
×
510
    }
×
511

512
    char buf;
3✔
513
    int res = SSL_peek(d_conn.get(), &buf, sizeof(buf));
3✔
514
    if (res > 0) {
3!
515
      return true;
×
516
    }
×
517
    try {
3✔
518
      convertIORequestToIOState(res);
3✔
519
      return true;
3✔
520
    }
3✔
521
    catch (...) {
3✔
522
      return false;
1✔
523
    }
1✔
524

525
    return false;
×
526
  }
3✔
527

528
  void close() override
529
  {
660✔
530
    if (d_conn) {
660!
531
      SSL_shutdown(d_conn.get());
660✔
532
    }
660✔
533
  }
660✔
534

535
  std::string getServerNameIndication() const override
536
  {
785✔
537
    if (d_conn) {
785!
538
      const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
785✔
539
      if (value) {
785✔
540
        return std::string(value);
782✔
541
      }
782✔
542
    }
785✔
543
    return std::string();
3✔
544
  }
785✔
545

546
  std::vector<uint8_t> getNextProtocol() const override
547
  {
194✔
548
    std::vector<uint8_t> result;
194✔
549
    if (!d_conn) {
194!
550
      return result;
×
551
    }
×
552

553
    const unsigned char* alpn = nullptr;
194✔
554
    unsigned int alpnLen  = 0;
194✔
555
#ifdef HAVE_SSL_GET0_ALPN_SELECTED
194✔
556
    if (alpn == nullptr) {
194!
557
      SSL_get0_alpn_selected(d_conn.get(), &alpn, &alpnLen);
194✔
558
    }
194✔
559
#endif /* HAVE_SSL_GET0_ALPN_SELECTED */
194✔
560
    if (alpn != nullptr && alpnLen > 0) {
194!
561
      result.insert(result.end(), alpn, alpn + alpnLen);
171✔
562
    }
171✔
563
    return result;
194✔
564
  }
194✔
565

566
  LibsslTLSVersion getTLSVersion() const override
567
  {
791✔
568
    auto proto = SSL_version(d_conn.get());
791✔
569
    switch (proto) {
791✔
570
    case TLS1_VERSION:
×
571
      return LibsslTLSVersion::TLS10;
×
572
    case TLS1_1_VERSION:
×
573
      return LibsslTLSVersion::TLS11;
×
574
    case TLS1_2_VERSION:
11✔
575
      return LibsslTLSVersion::TLS12;
11✔
576
#ifdef TLS1_3_VERSION
×
577
    case TLS1_3_VERSION:
780✔
578
      return LibsslTLSVersion::TLS13;
780✔
579
#endif /* TLS1_3_VERSION */
×
580
    default:
×
581
      return LibsslTLSVersion::Unknown;
×
582
    }
791✔
583
  }
791✔
584

585
  bool hasSessionBeenResumed() const override
586
  {
605✔
587
    if (d_conn) {
605!
588
      return SSL_session_reused(d_conn.get()) != 0;
605✔
589
    }
605✔
590
    return false;
×
591
  }
605✔
592

593
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
594
  {
79✔
595
    return std::move(d_tlsSessions);
79✔
596
  }
79✔
597

598
  void setSession(std::unique_ptr<TLSSession>& session) override
599
  {
81✔
600
    auto sess = dynamic_cast<OpenSSLSession*>(session.get());
81✔
601
    if (!sess) {
81!
602
      throw std::runtime_error("Unable to convert OpenSSL session");
×
603
    }
×
604

605
    auto native = sess->getNative();
81✔
606
    auto ret = SSL_set_session(d_conn.get(), native.get());
81✔
607
    if (ret != 1) {
81!
608
      throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
×
609
    }
×
610
    session.reset();
81✔
611
  }
81✔
612

613
  void addNewTicket(SSL_SESSION* session)
614
  {
141✔
615
    d_tlsSessions.push_back(std::make_unique<OpenSSLSession>(std::unique_ptr<SSL_SESSION, void (*)(SSL_SESSION*)>(session, SSL_SESSION_free)));
141✔
616
  }
141✔
617

618
  void enableKTLS()
619
  {
×
620
    d_ktls = true;
×
621
  }
×
622

623
  [[nodiscard]] bool isClient() const
624
  {
2,688✔
625
    return d_isClient;
2,688✔
626
  }
2,688✔
627

628
  static void generateConnectionIndexIfNeeded()
629
  {
109✔
630
    auto init = s_initTLSConnIndex.lock();
109✔
631
    if (*init == true) {
109✔
632
      return;
36✔
633
    }
36✔
634

635
    /* not initialized yet */
636
    s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
73✔
637
    if (s_tlsConnIndex == -1) {
73!
638
      throw std::runtime_error("Error getting an index for TLS connection data");
×
639
    }
×
640

641
    *init = true;
73✔
642
  }
73✔
643

644
  static int getConnectionIndex()
645
  {
839✔
646
    return s_tlsConnIndex;
839✔
647
  }
839✔
648

649
private:
650
  static LockGuarded<bool> s_initTLSConnIndex;
651
  static int s_tlsConnIndex;
652
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
653
  const std::shared_ptr<const OpenSSLTLSIOCtx> d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime
654
  std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
655
  const std::string d_hostname;
656
  const timeval d_timeout;
657
  bool d_connected{false};
658
  bool d_ktls{false};
659
  const bool d_isClient{false};
660
};
661

662
LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
663
int OpenSSLTLSConnection::s_tlsConnIndex{-1};
664

665
class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this<OpenSSLTLSIOCtx>
666
{
667
  /* Passkey type: being private, only members of this class can create one, which
     restricts the public constructors to the create*Context() factories. */
  struct Private
  {
    explicit Private() = default;
  };
671

672
public:
673
  /* Builds a server-side context for 'frontend'; exceptions from the constructor propagate. */
  static std::shared_ptr<OpenSSLTLSIOCtx> createServerSideContext(TLSFrontend& frontend)
  {
    auto context = std::make_shared<OpenSSLTLSIOCtx>(frontend, Private());
    return context;
  }
67✔
677

678
  static std::shared_ptr<OpenSSLTLSIOCtx> createClientSideContext(const TLSContextParameters& params)
679
  {
42✔
680
    return std::make_shared<OpenSSLTLSIOCtx>(params, Private());
42✔
681
  }
42✔
682

683
  /* server side context */
684
  OpenSSLTLSIOCtx(TLSFrontend& frontend, [[maybe_unused]] Private priv): d_alpnProtos(getALPNVector(frontend.d_alpn, false)), d_feContext(std::make_unique<OpenSSLFrontendContext>(frontend.d_addr, frontend.d_tlsConfig))
  {
    /* Server-side context: finishes configuring every SSL_CTX in the frontend's SNI map,
       then generates or loads the session ticket keys. Exceptions thrown while creating
       the frontend context or handling the ticket keys propagate to the caller. */
    OpenSSLTLSConnection::generateConnectionIndexIfNeeded();

    d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;

    for (auto& entry : d_feContext->d_sniMap) {
      auto* ctx = entry.second.get();
      if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) {
        /* use our own ticket keys handler so we can rotate them */
#if OPENSSL_VERSION_MAJOR >= 3
        SSL_CTX_set_tlsext_ticket_key_evp_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
#else
        SSL_CTX_set_tlsext_ticket_key_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
#endif
      }

      /* Attach the frontend context unconditionally: besides the ticket key callback,
         sni_server_name_callback() uses it to find the per-SNI contexts, so it has to be
         present even when session tickets are disabled, otherwise the default certificate
         would always be served. */
      libssl_set_ticket_key_callback_data(ctx, d_feContext.get());

#ifndef DISABLE_OCSP_STAPLING
      if (!d_feContext->d_ocspResponses.empty()) {
        SSL_CTX_set_tlsext_status_cb(ctx, &OpenSSLTLSIOCtx::ocspStaplingCb);
        SSL_CTX_set_tlsext_status_arg(ctx, &d_feContext->d_ocspResponses);
      }
#endif /* DISABLE_OCSP_STAPLING */

      if (frontend.d_tlsConfig.d_readAhead) {
        SSL_CTX_set_read_ahead(ctx, 1);
      }

      libssl_set_error_counters_callback(*ctx, &frontend.d_tlsCounters);

      libssl_set_alpn_select_callback(ctx, alpnServerSelectCallback, this);

      if (!frontend.d_tlsConfig.d_keyLogFile.empty()) {
        /* NOTE(review): with several SNI contexts, each iteration replaces the handle stored
           for the previous context, which may still be referenced by it — confirm the
           ownership semantics of libssl_set_key_log_file(). */
        d_feContext->d_keyLogFile = libssl_set_key_log_file(ctx, frontend.d_tlsConfig.d_keyLogFile);
      }
    }

    if (frontend.d_tlsConfig.d_ticketKeyFile.empty()) {
      handleTicketsKeyRotation(time(nullptr));
    }
    else {
      OpenSSLTLSIOCtx::loadTicketsKeys(frontend.d_tlsConfig.d_ticketKeyFile);
    }
  }
67✔
734

735
  /* client side context */
736
  OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private priv)
737
  {
42✔
738
    int sslOptions =
42✔
739
      SSL_OP_NO_SSLv2 |
42✔
740
      SSL_OP_NO_SSLv3 |
42✔
741
      SSL_OP_NO_COMPRESSION |
42✔
742
      SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
42✔
743
      SSL_OP_SINGLE_DH_USE |
42✔
744
      SSL_OP_SINGLE_ECDH_USE |
42✔
745
#ifdef SSL_OP_IGNORE_UNEXPECTED_EOF
42✔
746
      SSL_OP_IGNORE_UNEXPECTED_EOF |
42✔
747
#endif
42✔
748
      SSL_OP_CIPHER_SERVER_PREFERENCE;
42✔
749
    if (!params.d_enableRenegotiation) {
42!
750
#ifdef SSL_OP_NO_RENEGOTIATION
42✔
751
      sslOptions |= SSL_OP_NO_RENEGOTIATION;
42✔
752
#elif defined(SSL_OP_NO_CLIENT_RENEGOTIATION)
753
      sslOptions |= SSL_OP_NO_CLIENT_RENEGOTIATION;
754
#endif
755
    }
42✔
756

757
    if (params.d_ktls) {
42!
758
#ifdef SSL_OP_ENABLE_KTLS
×
759
      sslOptions |= SSL_OP_ENABLE_KTLS;
×
760
      d_ktls = true;
×
761
#endif /* SSL_OP_ENABLE_KTLS */
×
762
    }
×
763

764
    registerOpenSSLUser();
42✔
765

766
    OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
42✔
767

768
#ifdef HAVE_TLS_CLIENT_METHOD
769
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
770
#else
771
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
42✔
772
#endif
42✔
773
    if (!d_tlsCtx) {
42!
774
      ERR_print_errors_fp(stderr);
×
775
      throw std::runtime_error("Error creating TLS context");
×
776
    }
×
777

778
    SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
42✔
779
#if defined(SSL_CTX_set_ecdh_auto)
42✔
780
    SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
42✔
781
#endif
42✔
782

783
    if (!params.d_ciphers.empty()) {
42!
784
      if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
×
785
        ERR_print_errors_fp(stderr);
×
786
        throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
×
787
      }
×
788
    }
×
789
#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
42✔
790
    if (!params.d_ciphers13.empty()) {
42!
791
      if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
×
792
        ERR_print_errors_fp(stderr);
×
793
        throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
×
794
      }
×
795
    }
×
796
#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
42✔
797

798
    if (params.d_validateCertificates) {
42✔
799
      if (params.d_caStore.empty())  {
33!
800
        if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
×
801
          throw std::runtime_error("Error adding the system's default trusted CAs");
×
802
        }
×
803
      } else {
33✔
804
        if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
33!
805
          throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
×
806
        }
×
807
      }
33✔
808

809
      SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
33✔
810
#if (OPENSSL_VERSION_NUMBER < 0x10002000L)
811
      warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
812
#endif
813
    }
33✔
814

815
    /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
816
       but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
817
    SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE);
42✔
818
    SSL_CTX_sess_set_new_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb);
42✔
819

820
    if (!params.d_keyLogFile.empty()) {
42✔
821
      d_keyLogFile = libssl_set_key_log_file(d_tlsCtx.get(), params.d_keyLogFile);
4✔
822
    }
4✔
823

824
    libssl_set_alpn_protos(d_tlsCtx.get(), getALPNVector(params.d_alpn, true));
42✔
825

826
#ifdef SSL_MODE_RELEASE_BUFFERS
42✔
827
    if (params.d_releaseBuffers) {
42!
828
      SSL_CTX_set_mode(d_tlsCtx.get(), SSL_MODE_RELEASE_BUFFERS);
42✔
829
    }
42✔
830
#endif
42✔
831
  }
42✔
832

833
  OpenSSLTLSIOCtx(const OpenSSLTLSIOCtx&) = delete;
834
  OpenSSLTLSIOCtx(OpenSSLTLSIOCtx&&) = delete;
835
  OpenSSLTLSIOCtx& operator=(const OpenSSLTLSIOCtx&) = delete;
836
  OpenSSLTLSIOCtx& operator=(OpenSSLTLSIOCtx&&) = delete;
837

838
  ~OpenSSLTLSIOCtx() override
  {
    /* drop d_tlsCtx (only set by the client-side constructor in the code shown)
       before releasing this object's OpenSSL user registration */
    d_tlsCtx = nullptr;
    unregisterOpenSSLUser();
  }
15✔
843

844
#if OPENSSL_VERSION_MAJOR >= 3
845
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx, int enc)
846
#else
847
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx, int enc)
848
#endif
849
  {
849✔
850
    auto* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
849✔
851
    if (ctx == nullptr) {
849!
852
      return -1;
×
853
    }
×
854

855
    int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
849✔
856
    if (enc == 0) {
849✔
857
      if (ret == 0 || ret == 2) {
241✔
858
        auto* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::getConnectionIndex()));
14✔
859
        if (conn != nullptr) {
14!
860
          if (ret == 0) {
14✔
861
            conn->setUnknownTicketKey();
6✔
862
          }
6✔
863
          else if (ret == 2) {
8!
864
            conn->setResumedFromInactiveTicketKey();
8✔
865
          }
8✔
866
        }
14✔
867
      }
14✔
868
    }
241✔
869

870
    return ret;
849✔
871
  }
849✔
872

873
#ifndef DISABLE_OCSP_STAPLING
874
  static int ocspStaplingCb(SSL* ssl, void* arg)
875
  {
4✔
876
    if (ssl == nullptr || arg == nullptr) {
4!
877
      return SSL_TLSEXT_ERR_NOACK;
×
878
    }
×
879
    const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
4✔
880
    return libssl_ocsp_stapling_callback(ssl, *ocspMap);
4✔
881
  }
4✔
882
#endif /* DISABLE_OCSP_STAPLING */
883

884
  /* "New session" callback (registered with SSL_CTX_sess_set_new_cb): hands a ticket received
     from the server to the OpenSSLTLSConnection stored in the SSL ex-data.
     Returns 1 when the session was handed over, 0 otherwise. Per the OpenSSL contract, 1 means
     the application kept a reference to the session — presumably addNewTicket() takes
     ownership, verify against OpenSSLTLSConnection. */
  static int newTicketFromServerCb(SSL* ssl, SSL_SESSION* session)
  {
    auto* connection = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(ssl, OpenSSLTLSConnection::getConnectionIndex()));
    if (session != nullptr && connection != nullptr) {
      connection->addNewTicket(session);
      return 1;
    }
    return 0;
  }
894

895
  /* The SSL_CTX new connections are created from: the frontend's one for server contexts,
     our own d_tlsCtx for client contexts. */
  SSL_CTX* getOpenSSLContext() const
  {
    return d_feContext ? d_feContext->d_tlsCtx.get() : d_tlsCtx.get();
  }
902

903
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
904
  {
545✔
905
    handleTicketsKeyRotation(now);
545✔
906

907
    return std::make_unique<OpenSSLTLSConnection>(socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
545✔
908
  }
545✔
909

910
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
911
  {
139✔
912
    auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
139✔
913
    if (d_ktls) {
139!
914
      conn->enableKTLS();
×
915
    }
×
916
    return conn;
139✔
917
  }
139✔
918

919
  void rotateTicketsKey(time_t now) override
920
  {
111✔
921
    d_feContext->d_ticketKeys.rotateTicketsKey(now);
111✔
922

923
    if (d_ticketsKeyRotationDelay > 0) {
111!
924
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
111✔
925
    }
111✔
926
  }
111✔
927

928
  void loadTicketsKeys(const std::string& keyFile) final
929
  {
12✔
930
    d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
12✔
931

932
    if (d_ticketsKeyRotationDelay > 0) {
12!
933
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
12✔
934
    }
12✔
935
  }
12✔
936

937
  void loadTicketsKey(const std::string& key) final
938
  {
1✔
939
    d_feContext->d_ticketKeys.loadTicketsKey(key);
1✔
940

941
    if (d_ticketsKeyRotationDelay > 0) {
1!
942
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
1✔
943
    }
1✔
944
  }
1✔
945

946
  size_t getTicketsKeysCount() override
947
  {
12✔
948
    return d_feContext->d_ticketKeys.getKeysCount();
12✔
949
  }
12✔
950

951
  std::string getName() const override
952
  {
4✔
953
    return "openssl";
4✔
954
  }
4✔
955

956
  bool isServerContext() const
957
  {
×
958
    return d_feContext != nullptr;
×
959
  }
×
960

961
private:
  /* ALPN selection callback, invoked on the server side when the client advertised ALPN values:
     'in'/'inlen' is the client's wire-format list (a one-byte length followed by that many bytes,
     repeated) and 'arg' is the OpenSSLTLSIOCtx holding the protocols we support (d_alpnProtos).
     Implements the server-preference algorithm of RFC 7301 section 3.2: the first of OUR protocols
     that the client also offered wins.
     Returns SSL_TLSEXT_ERR_OK with *out/*outlen pointing inside 'in' on a match,
     SSL_TLSEXT_ERR_NOACK when there is no protocol in common, and SSL_TLSEXT_ERR_ALERT_WARNING
     when 'arg' is missing or the client's list is malformed.
     This runs inside OpenSSL's C code, so no exception may escape from it. */
  static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
  {
    if (arg == nullptr) {
      return SSL_TLSEXT_ERR_ALERT_WARNING;
    }
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
    const auto* obj = reinterpret_cast<const OpenSSLTLSIOCtx*>(arg);

    const pdns::views::UnsignedCharView inView(in, inlen);
    // Server preference algorithm as per RFC 7301 section 3.2
    for (const auto& tentative : obj->d_alpnProtos) {
      size_t pos = 0;
      while (pos < inView.size()) {
        const size_t protoLen = inView.at(pos);
        pos++;
        if (protoLen > (inlen - pos)) {
          /* something is very wrong */
          return SSL_TLSEXT_ERR_ALERT_WARNING;
        }

        /* RFC 7301 forbids empty protocol names; skipping them also guarantees that pos < inlen,
           so inView.at(pos) below cannot throw std::out_of_range into OpenSSL */
        if (protoLen > 0 && tentative.size() == protoLen && memcmp(&inView.at(pos), tentative.data(), protoLen) == 0) {
          *out = &inView.at(pos);
          *outlen = static_cast<unsigned char>(protoLen);
          return SSL_TLSEXT_ERR_OK;
        }
        pos += protoLen;
      }
    }

    return SSL_TLSEXT_ERR_NOACK;
  }
994

995
  const std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
996
  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
997
  std::unique_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
998
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
999
  bool d_ktls{false};
1000
};
1001

1002
#endif /* HAVE_LIBSSL */
1003

1004
#ifdef HAVE_GNUTLS
1005
#include <gnutls/gnutls.h>
1006
#include <gnutls/x509.h>
1007

1008
/* Best-effort locking of sensitive memory (key material) so that it is not swapped to disk.
   Only effective when built with libsodium (sodium_mlock); otherwise this is a no-op. */
static void safe_memory_lock([[maybe_unused]] void* data, [[maybe_unused]] size_t size)
{
#ifndef HAVE_LIBSODIUM
  /* no portable way to lock memory without libsodium: nothing to do */
#else
  sodium_mlock(data, size);
#endif
}
1014

1015
/* Wipe sensitive memory (key material, serialized sessions) in a way the compiler cannot
   optimize away. With libsodium, sodium_munlock() zeroes the buffer and unlocks it; otherwise
   the first available of explicit_bzero, explicit_memset or gnutls_memset is used, and as a
   last resort a memset whose effect is checked through a volatile read. */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size > 0) {
    volatile unsigned int probeIndex = 0;
    auto* bytes = reinterpret_cast<volatile unsigned char*>(data);
    for (;;) {
      memset(data, 0, size);
      /* the volatile read forces the memset above to actually happen */
      if (bytes[probeIndex] == 0) {
        break;
      }
    }
  }
#endif
}
1038

1039
class GnuTLSTicketsKey
1040
{
1041
public:
1042
  GnuTLSTicketsKey()
1043
  {
6✔
1044
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
6!
1045
      throw std::runtime_error("Error generating tickets key for TLS context");
1046
    }
1047

1048
    safe_memory_lock(d_key.data, d_key.size);
6✔
1049
  }
6✔
1050

1051
  GnuTLSTicketsKey(const std::string& key)
1052
  {
1✔
1053
    /* to be sure we are loading the correct amount of data, which
1054
       may change between versions, let's generate a correct key first */
1055
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
1!
1056
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
1057
    }
1058

1059
    safe_memory_lock(d_key.data, d_key.size);
1✔
1060
    if (key.size() != d_key.size) {
1!
1061
      safe_memory_release(d_key.data, d_key.size);
1062
      gnutls_free(d_key.data);
1063
      d_key.data = nullptr;
1064
      throw std::runtime_error("Invalid GnuTLS ticket key size");
1065
    }
1066
    memcpy(d_key.data, key.data(), key.size());
1✔
1067
  }
1✔
1068
  GnuTLSTicketsKey(std::ifstream& file)
1069
  {
1070
    /* to be sure we are loading the correct amount of data, which
1071
       may change between versions, let's generate a correct key first */
1072
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
×
1073
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
1074
    }
1075

1076
    safe_memory_lock(d_key.data, d_key.size);
1077

1078
    try {
1079
      file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
1080

1081
      if (file.fail()) {
×
1082
        throw std::runtime_error("Invalid GnuTLS tickets key file");
1083
      }
1084

1085
    }
1086
    catch (const std::exception& e) {
1087
      safe_memory_release(d_key.data, d_key.size);
1088
      gnutls_free(d_key.data);
1089
      d_key.data = nullptr;
1090
      throw;
1091
    }
1092
  }
1093
  [[nodiscard]] std::string content() const
1094
  {
2✔
1095
    std::string result{};
2✔
1096
    if (d_key.data != nullptr && d_key.size > 0) {
2!
1097
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1098
      result.append(reinterpret_cast<const char*>(d_key.data), d_key.size);
2✔
1099
      safe_memory_lock(result.data(), result.size());
2✔
1100
    }
2✔
1101
    return result;
2✔
1102
  }
2✔
1103

1104
  ~GnuTLSTicketsKey()
1105
  {
3✔
1106
    if (d_key.data != nullptr && d_key.size > 0) {
3!
1107
      safe_memory_release(d_key.data, d_key.size);
3✔
1108
    }
3✔
1109
    gnutls_free(d_key.data);
3✔
1110
    d_key.data = nullptr;
3✔
1111
  }
3✔
1112
  const gnutls_datum_t& getKey() const
1113
  {
14✔
1114
    return d_key;
14✔
1115
  }
14✔
1116

1117
private:
1118
  gnutls_datum_t d_key{nullptr, 0};
1119
};
1120

1121
class GnuTLSSession : public TLSSession
1122
{
1123
public:
1124
  GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
49✔
1125
  {
49✔
1126
    sess.data = nullptr;
49✔
1127
    sess.size = 0;
49✔
1128
  }
49✔
1129

1130
  ~GnuTLSSession() override
1131
  {
38✔
1132
    if (d_sess.data != nullptr && d_sess.size > 0) {
38!
1133
      safe_memory_release(d_sess.data, d_sess.size);
38✔
1134
    }
38✔
1135
    gnutls_free(d_sess.data);
38✔
1136
    d_sess.data = nullptr;
38✔
1137
  }
38✔
1138

1139
  const gnutls_datum_t& getNative()
1140
  {
38✔
1141
    return d_sess;
38✔
1142
  }
38✔
1143

1144
private:
1145
  gnutls_datum_t d_sess{nullptr, 0};
1146
};
1147

1148
class GnuTLSConnection: public TLSConnection
1149
{
1150
public:
1151
  /* server side connection */
1152
  GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
14✔
1153
  {
14✔
1154
    unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
14✔
1155
#ifdef GNUTLS_NO_SIGNAL
14✔
1156
    sslOptions |= GNUTLS_NO_SIGNAL;
14✔
1157
#endif
14✔
1158

1159
    d_socket = socket;
14✔
1160

1161
    gnutls_session_t conn;
14✔
1162
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
14!
1163
      throw std::runtime_error("Error creating TLS connection");
1164
    }
1165

1166
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
14✔
1167
    conn = nullptr;
14✔
1168

1169
    if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) {
14!
1170
      throw std::runtime_error("Error setting certificate and key to TLS connection");
1171
    }
1172

1173
    if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
14!
1174
      throw std::runtime_error("Error setting ciphers to TLS connection");
1175
    }
1176

1177
    if (enableTickets && d_ticketsKey) {
14!
1178
      const gnutls_datum_t& key = d_ticketsKey->getKey();
14✔
1179
      if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
14!
1180
        throw std::runtime_error("Error setting the tickets key to TLS connection");
1181
      }
1182
    }
14✔
1183

1184
    gnutls_transport_set_int(d_conn.get(), d_socket);
14✔
1185

1186
    /* timeouts are in milliseconds */
1187
    gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
14✔
1188
    gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
14✔
1189
  }
14✔
1190

1191
  /* client-side connection */
1192
  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
64✔
1193
  {
64✔
1194
    unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
64✔
1195
#ifdef GNUTLS_NO_SIGNAL
64✔
1196
    sslOptions |= GNUTLS_NO_SIGNAL;
64✔
1197
#endif
64✔
1198

1199
    d_socket = socket;
64✔
1200

1201
    gnutls_session_t conn;
64✔
1202
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
64!
1203
      throw std::runtime_error("Error creating TLS connection");
1204
    }
1205

1206
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
64✔
1207
    conn = nullptr;
64✔
1208

1209
    int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get());
64✔
1210
    if (rc != GNUTLS_E_SUCCESS) {
64!
1211
      throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
1212
    }
1213

1214
    rc = gnutls_priority_set(d_conn.get(), priorityCache);
64✔
1215
    if (rc != GNUTLS_E_SUCCESS) {
64!
1216
      throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
1217
    }
1218

1219
    gnutls_transport_set_int(d_conn.get(), d_socket);
64✔
1220

1221
    /* timeouts are in milliseconds */
1222
    gnutls_handshake_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
64✔
1223
    gnutls_record_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
64✔
1224

1225
#ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
64✔
1226
    if (validateCerts && !d_host.empty()) {
64!
1227
      gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
42✔
1228
      rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
42✔
1229
      if (rc != GNUTLS_E_SUCCESS) {
42!
1230
        throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
1231
      }
1232
    }
42✔
1233
#else
1234
    /* no hostname validation for you */
1235
#endif
1236

1237
    /* allow access to our data in the callbacks */
1238
    gnutls_session_set_ptr(d_conn.get(), this);
64✔
1239
    gnutls_handshake_set_hook_function(d_conn.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET, GNUTLS_HOOK_POST, newTicketFromServerCb);
64✔
1240
  }
64✔
1241

1242
  /* The callback prototype changed in 3.4.0. */
1243
#if GNUTLS_VERSION_NUMBER >= 0x030400
1244
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */, const gnutls_datum_t* /* msg */)
1245
#else
1246
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */)
1247
#endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */
1248
  {
49✔
1249
    if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) {
49!
1250
      return 0;
1251
    }
1252

1253
    GnuTLSConnection* conn = reinterpret_cast<GnuTLSConnection*>(gnutls_session_get_ptr(session));
49✔
1254
    if (conn == nullptr) {
49!
1255
      return 0;
1256
    }
1257

1258
    gnutls_datum_t sess{nullptr, 0};
49✔
1259
    auto ret = gnutls_session_get_data2(session, &sess);
49✔
1260
    /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */
1261
    if (ret != GNUTLS_E_SUCCESS || sess.size <= 4) {
49!
1262
      throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
1263
    }
1264
    conn->d_tlsSessions.push_back(std::make_unique<GnuTLSSession>(sess));
49✔
1265
    return 0;
49✔
1266
  }
49✔
1267

1268
  IOState tryConnect(bool fastOpen, [[maybe_unused]] const ComboAddress& remote) override
1269
  {
64✔
1270
    int ret = 0;
64✔
1271

1272
    if (fastOpen) {
64!
1273
#ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
1274
      gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
1275
#endif
1276
    }
1277

1278
    do {
64✔
1279
      ret = gnutls_handshake(d_conn.get());
64✔
1280
      if (ret == GNUTLS_E_SUCCESS) {
64!
1281
        d_handshakeDone = true;
1282
        return IOState::Done;
1283
      }
1284
      else if (ret == GNUTLS_E_AGAIN) {
64✔
1285
        int direction = gnutls_record_get_direction(d_conn.get());
46✔
1286
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
46!
1287
      }
46✔
1288
      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
18!
1289
        throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
18✔
1290
      }
18✔
1291
    } while (ret == GNUTLS_E_INTERRUPTED);
64!
1292

1293
    throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1294
  }
64✔
1295

1296
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override
1297
  {
1298
    struct timeval start = {0, 0};
1299
    struct timeval remainingTime = timeout;
1300
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
1301
      gettimeofday(&start, nullptr);
1302
    }
1303

1304
    IOState state;
1305
    do {
1306
      state = tryConnect(fastOpen, remote);
1307
      if (state == IOState::Done) {
×
1308
        return;
1309
      }
1310
      else if (state == IOState::NeedRead) {
×
1311
        int result = waitForData(d_socket, remainingTime.tv_sec, remainingTime.tv_usec);
1312
        if (result <= 0) {
×
1313
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1314
        }
1315
      }
1316
      else if (state == IOState::NeedWrite) {
×
1317
        int result = waitForRWData(d_socket, false, remainingTime.tv_sec, remainingTime.tv_usec);
1318
        if (result <= 0) {
×
1319
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1320
        }
1321
      }
1322

1323
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
1324
        struct timeval now;
1325
        gettimeofday(&now, nullptr);
1326
        struct timeval elapsed = now - start;
1327
        if (now < start || remainingTime < elapsed) {
×
1328
          throw runtime_error("Timeout while establishing TLS connection");
1329
        }
1330
        start = now;
1331
        remainingTime = remainingTime - elapsed;
1332
      }
1333
    }
1334
    while (state != IOState::Done);
×
1335
  }
1336

1337
  void doHandshake() override
1338
  {
1339
    int ret = 0;
1340
    do {
1341
      ret = gnutls_handshake(d_conn.get());
1342
      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
×
1343
        if (d_client) {
×
1344
          throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1345
        }
1346
        else {
1347
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1348
        }
1349
      }
1350
    }
1351
    while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED);
×
1352

1353
    d_handshakeDone = true;
1354
  }
1355

1356
  IOState tryHandshake() override
1357
  {
120✔
1358
    int ret = 0;
120✔
1359

1360
    do {
120✔
1361
      ret = gnutls_handshake(d_conn.get());
120✔
1362
      if (ret == GNUTLS_E_SUCCESS) {
120✔
1363
        d_handshakeDone = true;
54✔
1364
        return IOState::Done;
54✔
1365
      }
54✔
1366
      else if (ret == GNUTLS_E_AGAIN) {
66✔
1367
        int direction = gnutls_record_get_direction(d_conn.get());
60✔
1368
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
60!
1369
      }
60✔
1370
      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
6!
1371
        if (d_client) {
6!
1372
          std::string error;
6✔
1373
#ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
6✔
1374
          if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
6!
1375
            gnutls_datum_t out;
6✔
1376
            if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) {
6!
1377
              error = " (" + std::string(reinterpret_cast<const char*>(out.data)) + ")";
6✔
1378
              gnutls_free(out.data);
6✔
1379
            }
6✔
1380
          }
6✔
1381
#endif /* HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS */
6✔
1382
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)) + error);
6✔
1383
        }
6✔
1384
        else {
1385
          throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1386
        }
1387
      }
6✔
1388
    } while (ret == GNUTLS_E_INTERRUPTED);
120!
1389

1390
    if (d_client) {
×
1391
      throw std::runtime_error("Error establishinging a new connection: " + std::string(gnutls_strerror(ret)));
1392
    }
1393
    else {
1394
      throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1395
    }
1396
  }
1397

1398
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
1399
  {
298✔
1400
    if (!d_handshakeDone) {
298✔
1401
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1402
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
1403
      auto state = tryHandshake();
92✔
1404
      if (state != IOState::Done) {
92✔
1405
        return state;
46✔
1406
      }
46✔
1407
    }
92✔
1408

1409
    do {
252✔
1410
      ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
252✔
1411
      if (res == 0) {
252!
1412
        throw std::runtime_error("Error writing to TLS connection");
1413
      }
1414
      else if (res > 0) {
252✔
1415
        pos += static_cast<size_t>(res);
246✔
1416
      }
246✔
1417
      else if (res < 0) {
6!
1418
        if (gnutls_error_is_fatal(res)) {
×
1419
          throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1420
        }
1421
        else if (res == GNUTLS_E_AGAIN) {
×
1422
          return IOState::NeedWrite;
1423
        }
1424
        vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
×
1425
      }
1426
    }
252✔
1427
    while (pos < toWrite);
252!
1428
    return IOState::Done;
252✔
1429
  }
252✔
1430

1431
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
1432
  {
502✔
1433
    if (!d_handshakeDone) {
502!
1434
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1435
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
1436
      auto state = tryHandshake();
1437
      if (state != IOState::Done) {
×
1438
        return state;
1439
      }
1440
    }
1441

1442
    do {
502✔
1443
      ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
502✔
1444
      if (res == 0) {
502✔
1445
        throw std::runtime_error("EOF while reading from TLS connection");
3✔
1446
      }
3✔
1447
      else if (res > 0) {
499✔
1448
        pos += static_cast<size_t>(res);
346✔
1449
        if (allowIncomplete) {
346✔
1450
          break;
70✔
1451
        }
70✔
1452
      }
346✔
1453
      else if (res < 0) {
153!
1454
        if (gnutls_error_is_fatal(res)) {
153✔
1455
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
33✔
1456
        }
33✔
1457
        else if (res == GNUTLS_E_AGAIN) {
120!
1458
          return IOState::NeedRead;
120✔
1459
        }
120✔
1460
        vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
×
1461
      }
1462
    }
502✔
1463
    while (pos < toRead);
502!
1464
    return IOState::Done;
346✔
1465
  }
502✔
1466

1467
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
1468
  {
1469
    size_t got = 0;
1470
    struct timeval start{0,0};
1471
    struct timeval  remainingTime = totalTimeout;
1472
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
1473
      gettimeofday(&start, nullptr);
1474
    }
1475

1476
    do {
1477
      ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
1478
      if (res == 0) {
×
1479
        throw std::runtime_error("EOF while reading from TLS connection");
1480
      }
1481
      else if (res > 0) {
×
1482
        got += static_cast<size_t>(res);
1483
        if (allowIncomplete) {
×
1484
          break;
1485
        }
1486
      }
1487
      else if (res < 0) {
×
1488
        if (gnutls_error_is_fatal(res)) {
×
1489
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
1490
        }
1491
        else if (res == GNUTLS_E_AGAIN) {
×
1492
          int result = waitForData(d_socket, readTimeout.tv_sec, readTimeout.tv_usec);
1493
          if (result <= 0) {
×
1494
            throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
1495
          }
1496
        }
1497
        else {
1498
          vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
×
1499
        }
1500
      }
1501

1502
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
1503
        struct timeval now;
1504
        gettimeofday(&now, nullptr);
1505
        struct timeval elapsed = now - start;
1506
        if (now < start || remainingTime < elapsed) {
×
1507
          throw runtime_error("Timeout while reading data");
1508
        }
1509
        start = now;
1510
        remainingTime = remainingTime - elapsed;
1511
      }
1512
    }
1513
    while (got < bufferSize);
×
1514

1515
    return got;
1516
  }
1517

1518
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
1519
  {
1520
    size_t got = 0;
1521

1522
    do {
1523
      ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
1524
      if (res == 0) {
×
1525
        throw std::runtime_error("Error writing to TLS connection");
1526
      }
1527
      else if (res > 0) {
×
1528
        got += static_cast<size_t>(res);
1529
      }
1530
      else if (res < 0) {
×
1531
        if (gnutls_error_is_fatal(res)) {
×
1532
          throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1533
        }
1534
        else if (res == GNUTLS_E_AGAIN) {
×
1535
          int result = waitForRWData(d_socket, false, writeTimeout.tv_sec, writeTimeout.tv_usec);
1536
          if (result <= 0) {
×
1537
            throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
1538
          }
1539
        }
1540
        else {
1541
          vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
×
1542
        }
1543
      }
1544
    }
1545
    while (got < bufferSize);
×
1546

1547
    return got;
1548
  }
1549

1550
  bool isUsable() const override
1551
  {
2✔
1552
    if (!d_conn) {
2!
1553
      return false;
1554
    }
1555

1556
    /* as far as I can tell we can't peek so we cannot do better */
1557
    return isTCPSocketUsable(d_socket);
2✔
1558
  }
2✔
1559

1560
  std::string getServerNameIndication() const override
1561
  {
112✔
1562
    if (d_conn) {
112!
1563
      unsigned int type;
112✔
1564
      size_t name_len = 256;
112✔
1565
      std::string sni;
112✔
1566
      sni.resize(name_len);
112✔
1567

1568
      int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
112✔
1569
      if (res == GNUTLS_E_SUCCESS) {
112!
1570
        sni.resize(name_len);
112✔
1571
        return sni;
112✔
1572
      }
112✔
1573
    }
112✔
1574
    return std::string();
1575
  }
112✔
1576

1577
  std::vector<uint8_t> getNextProtocol() const override
1578
  {
1579
    std::vector<uint8_t> result;
1580
    if (!d_conn) {
×
1581
      return result;
1582
    }
1583
    gnutls_datum_t next;
1584
    if (gnutls_alpn_get_selected_protocol(d_conn.get(), &next) != GNUTLS_E_SUCCESS) {
×
1585
      return result;
1586
    }
1587
    result.insert(result.end(), next.data, next.data + next.size);
1588
    return result;
1589
  }
1590

1591
  LibsslTLSVersion getTLSVersion() const override
1592
  {
112✔
1593
    auto proto = gnutls_protocol_get_version(d_conn.get());
112✔
1594
    switch (proto) {
112✔
1595
    case GNUTLS_TLS1_0:
112!
1596
      return LibsslTLSVersion::TLS10;
1597
    case GNUTLS_TLS1_1:
112!
1598
      return LibsslTLSVersion::TLS11;
1599
    case GNUTLS_TLS1_2:
3✔
1600
      return LibsslTLSVersion::TLS12;
3✔
1601
#if GNUTLS_VERSION_NUMBER >= 0x030603
1602
    case GNUTLS_TLS1_3:
109✔
1603
      return LibsslTLSVersion::TLS13;
109✔
1604
#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
1605
    default:
112!
1606
      return LibsslTLSVersion::Unknown;
1607
    }
112✔
1608
  }
112✔
1609

1610
  bool hasSessionBeenResumed() const override
1611
  {
46✔
1612
    if (d_conn) {
46!
1613
      return gnutls_session_is_resumed(d_conn.get()) != 0;
46✔
1614
    }
46✔
1615
    return false;
1616
  }
46✔
1617

1618
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
1619
  {
36✔
1620
    return std::move(d_tlsSessions);
36✔
1621
  }
36✔
1622

1623
  void setSession(std::unique_ptr<TLSSession>& session) override
1624
  {
38✔
1625
    auto sess = dynamic_cast<GnuTLSSession*>(session.get());
38✔
1626
    if (!sess) {
38!
1627
      throw std::runtime_error("Unable to convert GnuTLS session");
1628
    }
1629

1630
    auto native = sess->getNative();
38✔
1631
    auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size);
38✔
1632
    if (ret != GNUTLS_E_SUCCESS) {
38!
1633
      throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
1634
    }
1635
    session.reset();
38✔
1636
  }
38✔
1637

1638
  void close() override
1639
  {
72✔
1640
    if (d_conn) {
72!
1641
      gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR);
72✔
1642
    }
72✔
1643
  }
72✔
1644

1645
  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos)
1646
  {
78✔
1647
    std::vector<gnutls_datum_t> values;
78✔
1648
    values.reserve(protos.size());
78✔
1649
    for (const auto& proto : protos) {
78✔
1650
      gnutls_datum_t value;
78✔
1651
      value.data = const_cast<uint8_t*>(proto.data());
78✔
1652
      value.size = proto.size();
78✔
1653
      values.push_back(value);
78✔
1654
    }
78✔
1655
    unsigned int flags = 0;
78✔
1656
#if GNUTLS_VERSION_NUMBER >= 0x030500
78✔
1657
    flags |= GNUTLS_ALPN_MANDATORY;
78✔
1658
#elif defined(GNUTLS_ALPN_MAND)
1659
    flags |= GNUTLS_ALPN_MAND;
1660
#endif
1661
    return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), flags);
78✔
1662
  }
78✔
1663

1664
  std::vector<int> getAsyncFDs() override
1665
  {
14✔
1666
    return {};
14✔
1667
  }
14✔
1668

1669
private:
1670
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1671
  std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
1672
  std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
1673
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
1674
  std::string d_host;
1675
  const bool d_client{false};
1676
  bool d_handshakeDone{false};
1677
};
1678

1679
class GnuTLSIOCtx: public TLSCtx
1680
{
1681
public:
1682
  /* server side context */
1683
  GnuTLSIOCtx(TLSFrontend& frontend): d_protos(getALPNVector(frontend.d_alpn, false)), d_enableTickets(frontend.d_tlsConfig.d_enableTickets)
6✔
1684
  {
6✔
1685
    int rc = 0;
6✔
1686
    d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;
6✔
1687

1688
    gnutls_certificate_credentials_t creds;
6✔
1689
    rc = gnutls_certificate_allocate_credentials(&creds);
6✔
1690
    if (rc != GNUTLS_E_SUCCESS) {
6!
1691
      throw std::runtime_error("Error allocating credentials for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1692
    }
1693

1694
    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
6✔
1695
    creds = nullptr;
6✔
1696

1697
    for (const auto& pair : frontend.d_tlsConfig.d_certKeyPairs) {
6✔
1698
      rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.d_cert.c_str(), pair.d_key->c_str(), GNUTLS_X509_FMT_PEM);
6✔
1699
      if (rc != GNUTLS_E_SUCCESS) {
6!
1700
        throw std::runtime_error("Error loading certificate ('" + pair.d_cert + "') and key ('" + pair.d_key.value() + "') for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1701
      }
1702
    }
6✔
1703

1704
#ifndef DISABLE_OCSP_STAPLING
6✔
1705
    size_t count = 0;
6✔
1706
    for (const auto& file : frontend.d_tlsConfig.d_ocspFiles) {
6✔
1707
      rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
3✔
1708
      if (rc != GNUTLS_E_SUCCESS) {
3✔
1709
        warnlog("Error loading OCSP response from file '%s' for certificate ('%s') and key ('%s') for TLS context on %s: %s", file, frontend.d_tlsConfig.d_certKeyPairs.at(count).d_cert, frontend.d_tlsConfig.d_certKeyPairs.at(count).d_key.value(), frontend.d_addr.toStringWithPort(), gnutls_strerror(rc));
1✔
1710
      }
1✔
1711
      ++count;
3✔
1712
    }
3✔
1713
#endif /* DISABLE_OCSP_STAPLING */
6✔
1714

1715
#if GNUTLS_VERSION_NUMBER >= 0x030600
6✔
1716
    rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
6✔
1717
    if (rc != GNUTLS_E_SUCCESS) {
6!
1718
      throw std::runtime_error("Error setting DH params for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
1719
    }
1720
#endif
6✔
1721

1722
    rc = gnutls_priority_init(&d_priorityCache, frontend.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : frontend.d_tlsConfig.d_ciphers.c_str(), nullptr);
6!
1723
    if (rc != GNUTLS_E_SUCCESS) {
6!
1724
      throw std::runtime_error("Error setting up TLS cipher preferences to '" + frontend.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + frontend.d_addr.toStringWithPort());
1725
    }
1726

1727
    try {
6✔
1728
      if (frontend.d_tlsConfig.d_ticketKeyFile.empty()) {
6!
1729
        handleTicketsKeyRotation(time(nullptr));
6✔
1730
      }
6✔
1731
      else {
1732
        GnuTLSIOCtx::loadTicketsKeys(frontend.d_tlsConfig.d_ticketKeyFile);
1733
      }
1734
    }
6✔
1735
    catch(const std::runtime_error& e) {
6✔
1736
      throw std::runtime_error("Error generating tickets key for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + e.what());
1737
    }
1738
  }
6✔
1739

1740
  /* client side context */
1741
  GnuTLSIOCtx(const TLSContextParameters& params): d_protos(getALPNVector(params.d_alpn, true)), d_contextParameters(std::make_unique<TLSContextParameters>(params)), d_validateCerts(params.d_validateCertificates)
18✔
1742
  {
18✔
1743
    int rc = 0;
18✔
1744

1745
    gnutls_certificate_credentials_t creds;
18✔
1746
    rc = gnutls_certificate_allocate_credentials(&creds);
18✔
1747
    if (rc != GNUTLS_E_SUCCESS) {
18!
1748
      throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1749
    }
1750

1751
    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
18✔
1752
    creds = nullptr;
18✔
1753

1754
    if (params.d_validateCertificates) {
18✔
1755
      if (params.d_caStore.empty()) {
14!
1756
#if GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703
1757
        /* see https://gitlab.com/gnutls/gnutls/-/issues/1277 */
1758
        std::cerr<<"Warning: GnuTLS 3.7.0 - 3.7.2 have a memory leak when validating server certificates in some configurations (PKCS11 support enabled, and a default PKCS11 trust store), please consider upgrading GnuTLS, using the OpenSSL provider for outgoing connections, or explicitly setting a CA store"<<std::endl;
1759
#endif /* GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 */
1760
        rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
1761
        if (rc < 0) {
×
1762
          throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1763
        }
1764
      }
1765
      else {
14✔
1766
        rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
14✔
1767
        if (rc < 0) {
14!
1768
          throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1769
        }
1770
      }
14✔
1771
    }
14✔
1772

1773
    rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
18!
1774
    if (rc != GNUTLS_E_SUCCESS) {
18!
1775
      throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
1776
    }
1777
  }
18✔
1778

1779
  ~GnuTLSIOCtx() override
1780
  {
2✔
1781
    d_creds.reset();
2✔
1782

1783
    if (d_priorityCache) {
2!
1784
      gnutls_priority_deinit(d_priorityCache);
2✔
1785
    }
2✔
1786
  }
2✔
1787

1788
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
1789
  {
14✔
1790
    handleTicketsKeyRotation(now);
14✔
1791

1792
    std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
14✔
1793
    {
14✔
1794
      ticketsKey = *(d_ticketsKey.read_lock());
14✔
1795
    }
14✔
1796

1797
    auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets);
14✔
1798
    if (!d_protos.empty()) {
14!
1799
      connection->setALPNProtos(d_protos);
14✔
1800
    }
14✔
1801
    return connection;
14✔
1802
  }
14✔
1803

1804
  static std::shared_ptr<gnutls_certificate_credentials_st> getPerThreadCredentials(bool validate, const std::string& caStore)
1805
  {
64✔
1806
    static thread_local std::map<std::pair<bool, std::string>, std::shared_ptr<gnutls_certificate_credentials_st>> t_credentials;
64✔
1807
    auto& entry = t_credentials[{validate, caStore}];
64✔
1808
    if (!entry) {
64✔
1809
      gnutls_certificate_credentials_t creds;
18✔
1810
      int rc = gnutls_certificate_allocate_credentials(&creds);
18✔
1811
      if (rc != GNUTLS_E_SUCCESS) {
18!
1812
        throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1813
      }
1814

1815
      entry = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
18✔
1816
      creds = nullptr;
18✔
1817

1818
      if (validate) {
18✔
1819
        if (caStore.empty()) {
13!
1820
          rc = gnutls_certificate_set_x509_system_trust(entry.get());
1821
          if (rc < 0) {
×
1822
            throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1823
          }
1824
        }
1825
        else {
13✔
1826
          rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM);
13✔
1827
          if (rc < 0) {
13!
1828
            throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1829
          }
1830
        }
13✔
1831
      }
13✔
1832
    }
18✔
1833
    return entry;
64✔
1834
  }
64✔
1835

1836
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool, int socket, const struct timeval& timeout) override
1837
  {
64✔
1838
    auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore);
64✔
1839
    auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, creds, d_priorityCache, d_validateCerts);
64✔
1840
    if (!d_protos.empty()) {
64!
1841
      connection->setALPNProtos(d_protos);
64✔
1842
    }
64✔
1843
    return connection;
64✔
1844
  }
64✔
1845

1846
  void addTicketsKey(time_t now, std::shared_ptr<GnuTLSTicketsKey>&& newKey)
1847
  {
7✔
1848
    if (!d_enableTickets) {
7!
1849
      return;
1850
    }
1851

1852
    {
7✔
1853
      *(d_ticketsKey.write_lock()) = std::move(newKey);
7✔
1854
    }
7✔
1855

1856
    if (d_ticketsKeyRotationDelay > 0) {
7!
1857
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
7✔
1858
    }
7✔
1859

1860
    if (TLSCtx::hasTicketsKeyAddedHook()) {
7✔
1861
      auto ticketsKey = *(d_ticketsKey.read_lock());
2✔
1862
      auto content = ticketsKey->content();
2✔
1863
      TLSCtx::getTicketsKeyAddedHook()(content);
2✔
1864
      safe_memory_release(content.data(), content.size());
2✔
1865
    }
2✔
1866
  }
7✔
1867
  void rotateTicketsKey(time_t now) override
1868
  {
6✔
1869
    if (!d_enableTickets) {
6!
1870
      return;
1871
    }
1872

1873
    auto newKey = std::make_shared<GnuTLSTicketsKey>();
6✔
1874
    addTicketsKey(now, std::move(newKey));
6✔
1875
  }
6✔
1876
  void loadTicketsKey(const std::string& key) final
1877
  {
1✔
1878
    if (!d_enableTickets) {
1!
1879
      return;
1880
    }
1881

1882
    auto newKey = std::make_shared<GnuTLSTicketsKey>(key);
1✔
1883
    addTicketsKey(time(nullptr), std::move(newKey));
1✔
1884
  }
1✔
1885

1886
  void loadTicketsKeys(const std::string& keyFile) final
1887
  {
1888
    if (!d_enableTickets) {
×
1889
      return;
1890
    }
1891

1892
    std::ifstream file(keyFile);
1893
    auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1894
    addTicketsKey(time(nullptr), std::move(newKey));
1895
    file.close();
1896
  }
1897

1898
  size_t getTicketsKeysCount() override
1899
  {
1900
    return *(d_ticketsKey.read_lock()) != nullptr ? 1 : 0;
×
1901
  }
1902

1903
  std::string getName() const override
1904
  {
3✔
1905
    return "gnutls";
3✔
1906
  }
3✔
1907

1908
private:
1909
  /* client context parameters */
1910
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1911
  const std::vector<std::vector<uint8_t>> d_protos;
1912
  std::unique_ptr<TLSContextParameters> d_contextParameters{nullptr};
1913
  gnutls_priority_t d_priorityCache{nullptr};
1914
  SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
1915
  bool d_enableTickets{true};
1916
  bool d_validateCerts{true};
1917
};
1918

1919
#endif /* HAVE_GNUTLS */
1920

1921
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1922

1923
/* (Re)build this frontend's TLS context and publish it atomically in d_ctx.
   If a parent frontend exists and already has a context, that context is
   reused. Otherwise the provider named in d_provider is tried first. When
   that yields nothing, the fallback is OpenSSL if available, else GnuTLS.
   Without DoT/DoH support this does nothing. Always returns true. */
bool TLSFrontend::setupTLS()
{
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
  std::shared_ptr<TLSCtx> context{nullptr};

  if (d_parentFrontend) {
    context = d_parentFrontend->getContext();
  }

  if (!context) {
    /* explicitly requested provider, when compiled in */
#if defined(HAVE_GNUTLS)
    if (d_provider == "gnutls") {
      context = std::make_shared<GnuTLSIOCtx>(*this);
    }
#endif /* HAVE_GNUTLS */
#if defined(HAVE_LIBSSL)
    if (d_provider == "openssl") {
      context = OpenSSLTLSIOCtx::createServerSideContext(*this);
    }
#endif /* HAVE_LIBSSL */
  }

  if (!context) {
    /* the "best" available provider */
#if defined(HAVE_LIBSSL)
    context = OpenSSLTLSIOCtx::createServerSideContext(*this);
#elif defined(HAVE_GNUTLS)
    context = std::make_shared<GnuTLSIOCtx>(*this);
#else
#error "TLS support needed but neither libssl nor GnuTLS were selected"
#endif
  }

  std::atomic_store_explicit(&d_ctx, std::move(context), std::memory_order_release);
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
  return true;
}
1961

1962
std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
1963
{
60✔
1964
#ifdef HAVE_DNS_OVER_TLS
60✔
1965
  /* get the "best" available provider */
1966
  if (!params.d_provider.empty()) {
60✔
1967
#if defined(HAVE_GNUTLS)
48✔
1968
    if (params.d_provider == "gnutls") {
48✔
1969
      return std::make_shared<GnuTLSIOCtx>(params);
18✔
1970
    }
18✔
1971
#endif /* HAVE_GNUTLS */
30✔
1972
#if defined(HAVE_LIBSSL)
35✔
1973
    if (params.d_provider == "openssl") {
35!
1974
      return OpenSSLTLSIOCtx::createClientSideContext(params);
35✔
1975
    }
35✔
1976
#endif /* HAVE_LIBSSL */
35✔
1977
  }
35✔
1978

1979
#if defined(HAVE_LIBSSL)
7✔
1980
  return OpenSSLTLSIOCtx::createClientSideContext(params);
7✔
1981
#elif defined(HAVE_GNUTLS)
1982
  return std::make_shared<GnuTLSIOCtx>(params);
1983
#else
1984
#error "DNS over TLS support needed but neither libssl nor GnuTLS were selected"
1985
#endif
1986

1987
#endif /* HAVE_DNS_OVER_TLS */
×
1988
  return nullptr;
×
1989
}
60✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc