• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 18409242756

10 Oct 2025 02:16PM UTC coverage: 19.38% (-44.8%) from 64.13%
18409242756

push

github

web-flow
Merge pull request #16245 from miodvallat/matriochka_exception

auth: yet another logic botch

3972 of 30808 branches covered (12.89%)

Branch coverage included in aggregate %.

11562 of 49346 relevant lines covered (23.43%)

3168.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pdns/tcpiohandler.cc
1

2
#include "config.h"
3
#include "dolog.hh"
4
#include "iputils.hh"
5
#include "lock.hh"
6
#include "tcpiohandler.hh"
7

8
// Definition of the unit-test switch for this translation unit: connecting is not disabled here.
// NOTE(review): presumably test builds provide a different definition — confirm against the test sources.
const bool TCPIOHandler::s_disableConnectForUnitTests = false;
9

10
#ifdef HAVE_LIBSODIUM
11
#include <sodium.h>
12
#endif /* HAVE_LIBSODIUM */
13

14
// Optional hook shared by all TLS contexts; no hook is installed by default.
// NOTE(review): judging by its name it is called when a new tickets key is added — confirm at the call site.
TLSCtx::tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr};
15

16
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
17
static std::vector<std::vector<uint8_t>> getALPNVector(TLSFrontend::ALPN alpn, bool client)
18
{
×
19
  if (alpn == TLSFrontend::ALPN::DoT) {
×
20
    /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */
21
    return std::vector<std::vector<uint8_t>>{{'d', 'o', 't'}};
×
22
  }
×
23
  if (alpn == TLSFrontend::ALPN::DoH) {
×
24
    if (client) {
×
25
      /* we want to set the ALPN to h2, if only to mitigate the ALPACA attack */
26
      return std::vector<std::vector<uint8_t>>{{'h', '2'}};
×
27
    }
×
28
    /* For server contexts, we want to set the ALPN for DoH (note that h2o sets it own ALPN values):
29
       - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response
30
       - HTTP/2
31
    */
32
    return std::vector<std::vector<uint8_t>>{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}};
×
33
  }
×
34
  return {};
×
35
}
×
36

37
#ifdef HAVE_LIBSSL
38

39
namespace {
/* Tells whether detailed OpenSSL error information should be emitted:
   - dnsdist follows the "verbose" flag of its current runtime configuration,
   - the recursor never does,
   - every other build always does. */
bool shouldDoVerboseLogging()
{
#if defined(DNSDIST)
  const auto& runtimeConfig = dnsdist::configuration::getCurrentRuntimeConfiguration();
  return runtimeConfig.d_verbose;
#elif defined(RECURSOR)
  return false;
#else
  return true;
#endif
}
} // namespace
51

52
#include <openssl/conf.h>
53
#include <openssl/err.h>
54
#include <openssl/rand.h>
55
#include <openssl/ssl.h>
56
#include <openssl/x509v3.h>
57

58
#include "libssl.hh"
59

60
static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* arg);
61

62
class OpenSSLFrontendContext
63
{
64
public:
65
  OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
66
  {
×
67
    registerOpenSSLUser();
×
68

69
    auto [ctx, warnings] = libssl_init_server_context(tlsConfig);
×
70
    for (const auto& warning : warnings) {
×
71
      warnlog("%s", warning);
×
72
    }
×
73
    // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): it cannot be initialized before calling libssl_init_server_context()
74
    d_ocspResponses = std::move(ctx.d_ocspResponses);
×
75
    // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): it cannot be initialized before calling libssl_init_server_context()
76
    d_tlsCtx = std::move(ctx.d_defaultContext);
×
77
    // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): it cannot be initialized before calling libssl_init_server_context()
78
    d_sniMap = std::move(ctx.d_sniMap);
×
79
    for (auto& entry : d_sniMap) {
×
80
      SSL_CTX_set_tlsext_servername_callback(entry.second.get(), &sni_server_name_callback);
×
81
    }
×
82

83
    if (!d_tlsCtx) {
×
84
      ERR_print_errors_fp(stderr);
×
85
      throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
×
86
    }
×
87
  }
×
88

89
  void cleanup()
90
  {
×
91
    d_tlsCtx.reset();
×
92

×
93
    unregisterOpenSSLUser();
×
94
  }
×
95

96
  OpenSSLTLSTicketKeysRing d_ticketKeys;
97
  std::map<int, std::string> d_ocspResponses;
98
  pdns::libssl::ServerContext::SNIToContextMap d_sniMap;
99
  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr};
100
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
101
};
102

103

104
/* SNI callback: if the client sent a host name for which a dedicated context exists,
   switch the connection to that context; otherwise keep the default certificate.
   The frontend context is retrieved through the ticket key callback data of the connection.
   Returns SSL_TLSEXT_ERR_NOACK when no host name was sent, SSL_TLSEXT_ERR_OK otherwise. */
static int sni_server_name_callback(SSL* ssl, int* /* alert */, void* /* arg */)
{
  const char* requestedName = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
  if (requestedName == nullptr) {
    return SSL_TLSEXT_ERR_NOACK;
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
  auto* frontend = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(ssl));
  if (frontend != nullptr) {
    auto match = frontend->d_sniMap.find(std::string_view(requestedName));
    if (match != frontend->d_sniMap.end()) {
      /* if it fails there is nothing we can do,
         let's hope OpenSSL will fall back to the existing,
         default certificate */
      SSL_set_SSL_CTX(ssl, match->second.get());
    }
    /* no match: keep the default certificate */
  }

  return SSL_TLSEXT_ERR_OK;
}
×
130

131
class OpenSSLSession : public TLSSession
132
{
133
public:
134
  OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>&& sess): d_sess(std::move(sess))
135
  {
×
136
  }
×
137

138
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
139
  {
×
140
    return std::move(d_sess);
×
141
  }
×
142

143
private:
144
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
145
};
146

147
class OpenSSLTLSIOCtx;
148

149
class OpenSSLTLSConnection: public TLSConnection
150
{
151
public:
152
  /* server side connection */
153
  OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout)
154
  {
×
155
    d_socket = socket;
×
156

157
    if (!d_conn) {
×
158
      vinfolog("Error creating TLS object");
×
159
      if (shouldDoVerboseLogging()) {
×
160
        ERR_print_errors_fp(stderr);
×
161
      }
×
162
      throw std::runtime_error("Error creating TLS object");
×
163
    }
×
164

165
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
×
166
      throw std::runtime_error("Error assigning socket");
×
167
    }
×
168

169
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
×
170
  }
×
171

172
  /* client-side connection */
173
  OpenSSLTLSConnection(std::string hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true)
174
  {
×
175
    d_socket = socket;
×
176

177
    if (!d_conn) {
×
178
      vinfolog("Error creating TLS object");
×
179
      if (shouldDoVerboseLogging()) {
×
180
        ERR_print_errors_fp(stderr);
×
181
      }
×
182
      throw std::runtime_error("Error creating TLS object");
×
183
    }
×
184

185
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
×
186
      throw std::runtime_error("Error assigning socket");
×
187
    }
×
188

189
    /* set outgoing Server Name Indication */
190
    if (!d_hostname.empty() && SSL_set_tlsext_host_name(d_conn.get(), d_hostname.c_str()) != 1) {
×
191
      throw std::runtime_error("Error setting TLS SNI to " + d_hostname);
×
192
    }
×
193

194
    if (hostIsAddr) {
×
195
#if (OPENSSL_VERSION_NUMBER >= 0x10002000L)
×
196
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
×
197
      /* Enable automatic IP checks */
198
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
×
199
      if (X509_VERIFY_PARAM_set1_ip_asc(param, d_hostname.c_str()) != 1) {
×
200
        throw std::runtime_error("Error setting TLS IP for certificate validation");
×
201
      }
×
202
#else
203
      /* no validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
204
#endif
205
    }
×
206
    else {
×
207
#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl
×
208
      SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
×
209
      if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
×
210
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
×
211
      }
×
212
#elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
213
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
214
      /* Enable automatic hostname checks */
215
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
216
      if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
217
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
218
      }
219
#else
220
      /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
221
#endif
222
    }
×
223

224
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
×
225
  }
×
226

227
  std::vector<int> getAsyncFDs() override
228
  {
×
229
    std::vector<int> results;
×
230
#ifdef SSL_MODE_ASYNC
×
231
    if (SSL_waiting_for_async(d_conn.get()) != 1) {
×
232
      return results;
×
233
    }
×
234

235
    OSSL_ASYNC_FD fds[32];
×
236
    size_t numfds = sizeof(fds)/sizeof(*fds);
×
237
    SSL_get_all_async_fds(d_conn.get(), nullptr, &numfds);
×
238
    if (numfds == 0) {
×
239
      return results;
×
240
    }
×
241

242
    SSL_get_all_async_fds(d_conn.get(), fds, &numfds);
×
243
    results.reserve(numfds);
×
244
    for (size_t idx = 0; idx < numfds; idx++) {
×
245
      results.push_back(fds[idx]);
×
246
    }
×
247
#endif
×
248
    return results;
×
249
  }
×
250

251
  IOState convertIORequestToIOState(int res) const
252
  {
×
253
    int error = SSL_get_error(d_conn.get(), res);
×
254
    if (error == SSL_ERROR_WANT_READ) {
×
255
      return IOState::NeedRead;
×
256
    }
×
257
    else if (error == SSL_ERROR_WANT_WRITE) {
×
258
      return IOState::NeedWrite;
×
259
    }
×
260
    else if (error == SSL_ERROR_SYSCALL) {
×
261
      if (errno == 0) {
×
262
        throw std::runtime_error("TLS connection closed by remote end");
×
263
      }
×
264
      else {
×
265
        throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
×
266
      }
×
267
    }
×
268
    else if (error == SSL_ERROR_ZERO_RETURN) {
×
269
      throw std::runtime_error("TLS connection closed by remote end");
×
270
    }
×
271
#ifdef SSL_MODE_ASYNC
×
272
    else if (error == SSL_ERROR_WANT_ASYNC) {
×
273
      return IOState::Async;
×
274
    }
×
275
#endif
×
276
    else {
×
277
      if (shouldDoVerboseLogging()) {
×
278
        throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error) + ") " + libssl_get_error_string());
×
279
      } else {
×
280
        throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
×
281
      }
×
282
    }
×
283
  }
×
284

285
  void handleIORequest(int res, const struct timeval& timeout)
286
  {
×
287
    auto state = convertIORequestToIOState(res);
×
288
    if (state == IOState::NeedRead) {
×
289
      res = waitForData(d_socket, timeout);
×
290
      if (res == 0) {
×
291
        throw std::runtime_error("Timeout while reading from TLS connection");
×
292
      }
×
293
      else if (res < 0) {
×
294
        throw std::runtime_error("Error waiting to read from TLS connection");
×
295
      }
×
296
    }
×
297
    else if (state == IOState::NeedWrite) {
×
298
      res = waitForRWData(d_socket, false, timeout);
×
299
      if (res == 0) {
×
300
        throw std::runtime_error("Timeout while writing to TLS connection");
×
301
      }
×
302
      else if (res < 0) {
×
303
        throw std::runtime_error("Error waiting to write to TLS connection");
×
304
      }
×
305
    }
×
306
  }
×
307

308
  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
309
  {
×
310
    /* sorry */
311
    (void) fastOpen;
×
312
    (void) remote;
×
313

314
    int res = SSL_connect(d_conn.get());
×
315
    if (res == 1) {
×
316
      return IOState::Done;
×
317
    }
×
318
    else if (res < 0) {
×
319
      return convertIORequestToIOState(res);
×
320
    }
×
321

322
    throw std::runtime_error("Error establishing a TLS connection");
×
323
  }
×
324

325
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval &timeout) override
326
  {
×
327
    /* sorry */
328
    (void) fastOpen;
×
329
    (void) remote;
×
330

331
    struct timeval start{0,0};
×
332
    struct timeval remainingTime = timeout;
×
333
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
334
      gettimeofday(&start, nullptr);
×
335
    }
×
336

337
    int res = 0;
×
338
    do {
×
339
      res = SSL_connect(d_conn.get());
×
340
      if (res < 0) {
×
341
        handleIORequest(res, remainingTime);
×
342
      }
×
343

344
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
345
        struct timeval now;
×
346
        gettimeofday(&now, nullptr);
×
347
        struct timeval elapsed = now - start;
×
348
        if (now < start || remainingTime < elapsed) {
×
349
          throw runtime_error("Timeout while establishing TLS connection");
×
350
        }
×
351
        start = now;
×
352
        remainingTime = remainingTime - elapsed;
×
353
      }
×
354
    }
×
355
    while (res != 1);
×
356
  }
×
357

358
  IOState tryHandshake() override
359
  {
×
360
    if (isClient()) {
×
361
      /* In client mode, the handshake is initiated by the call to SSL_connect()
362
         done from connect()/tryConnect().
363
         In blocking mode it does not return before the handshake has been finished,
364
         and in non-blocking mode calling SSL_connect() once is enough for SSL_write()
365
         and SSL_read() to transparently continue to negotiate the connection after that
366
         (equivalent to doing SSL_set_connect_state() plus trying to write).
367
      */
368
      return IOState::Done;
×
369
    }
×
370

371
    /* As explained above in the client-mode block, we only need to call SSL_accept() once
372
       for SSL_write() and SSL_read() to transparently continue to negotiate the connection after that.
373
       It is equivalent to calling SSL_set_accept_state() plus trying to read.
374
    */
375
    int res = SSL_accept(d_conn.get());
×
376
    if (res == 1) {
×
377
      return IOState::Done;
×
378
    }
×
379
    else if (res < 0) {
×
380
      return convertIORequestToIOState(res);
×
381
    }
×
382

383
    throw std::runtime_error("Error accepting TLS connection");
×
384
  }
×
385

386
  void doHandshake() override
387
  {
×
388
    if (isClient()) {
×
389
      /* we are a client, nothing to do, see the non-blocking version */
390
      return;
×
391
    }
×
392

393
    int res = 0;
×
394
    do {
×
395
      res = SSL_accept(d_conn.get());
×
396
      if (res < 0) {
×
397
        handleIORequest(res, d_timeout);
×
398
      }
×
399
    }
×
400
    while (res < 0);
×
401

402
    if (res != 1) {
×
403
      throw std::runtime_error("Error accepting TLS connection");
×
404
    }
×
405
  }
×
406

407
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
408
  {
×
409
    if (isClient() && !d_connected) {
×
410
      if (d_ktls) {
×
411
        /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
412
        SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
×
413
      }
×
414
    }
×
415

416
    do {
×
417
      int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
×
418
      if (res <= 0) {
×
419
        return convertIORequestToIOState(res);
×
420
      }
×
421
      else {
×
422
        pos += static_cast<size_t>(res);
×
423
      }
×
424
    }
×
425
    while (pos < toWrite);
×
426

427
    if (!d_connected) {
×
428
      d_connected = true;
×
429
    }
×
430

431
    return IOState::Done;
×
432
  }
×
433

434
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
435
  {
×
436
    do {
×
437
      int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
×
438
      if (res <= 0) {
×
439
        return convertIORequestToIOState(res);
×
440
      }
×
441
      else {
×
442
        pos += static_cast<size_t>(res);
×
443
        if (allowIncomplete) {
×
444
          break;
×
445
        }
×
446
      }
×
447
    }
×
448
    while (pos < toRead);
×
449
    return IOState::Done;
×
450
  }
×
451

452
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
453
  {
×
454
    size_t got = 0;
×
455
    struct timeval start = {0, 0};
×
456
    struct timeval remainingTime = totalTimeout;
×
457
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
458
      gettimeofday(&start, nullptr);
×
459
    }
×
460

461
    do {
×
462
      int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
463
      if (res <= 0) {
×
464
        handleIORequest(res, readTimeout);
×
465
      }
×
466
      else {
×
467
        got += static_cast<size_t>(res);
×
468
        if (allowIncomplete) {
×
469
          break;
×
470
        }
×
471
      }
×
472

473
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
474
        struct timeval now;
×
475
        gettimeofday(&now, nullptr);
×
476
        struct timeval elapsed = now - start;
×
477
        if (now < start || remainingTime < elapsed) {
×
478
          throw runtime_error("Timeout while reading data");
×
479
        }
×
480
        start = now;
×
481
        remainingTime = remainingTime - elapsed;
×
482
      }
×
483
    }
×
484
    while (got < bufferSize);
×
485

486
    return got;
×
487
  }
×
488

489
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
490
  {
×
491
    size_t got = 0;
×
492
    do {
×
493
      int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
494
      if (res <= 0) {
×
495
        handleIORequest(res, writeTimeout);
×
496
      }
×
497
      else {
×
498
        got += static_cast<size_t>(res);
×
499
      }
×
500
    }
×
501
    while (got < bufferSize);
×
502

503
    return got;
×
504
  }
×
505

506
  bool isUsable() const override
507
  {
×
508
    if (!d_conn) {
×
509
      return false;
×
510
    }
×
511

512
    char buf;
×
513
    int res = SSL_peek(d_conn.get(), &buf, sizeof(buf));
×
514
    if (res > 0) {
×
515
      return true;
×
516
    }
×
517
    try {
×
518
      convertIORequestToIOState(res);
×
519
      return true;
×
520
    }
×
521
    catch (...) {
×
522
      return false;
×
523
    }
×
524

525
    return false;
×
526
  }
×
527

528
  void close() override
529
  {
×
530
    if (d_conn) {
×
531
      SSL_shutdown(d_conn.get());
×
532
    }
×
533
  }
×
534

535
  std::string getServerNameIndication() const override
536
  {
×
537
    if (d_conn) {
×
538
      const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
×
539
      if (value) {
×
540
        return std::string(value);
×
541
      }
×
542
    }
×
543
    return std::string();
×
544
  }
×
545

546
  std::vector<uint8_t> getNextProtocol() const override
547
  {
×
548
    std::vector<uint8_t> result;
×
549
    if (!d_conn) {
×
550
      return result;
×
551
    }
×
552

553
    const unsigned char* alpn = nullptr;
×
554
    unsigned int alpnLen  = 0;
×
555
#ifdef HAVE_SSL_GET0_ALPN_SELECTED
×
556
    if (alpn == nullptr) {
×
557
      SSL_get0_alpn_selected(d_conn.get(), &alpn, &alpnLen);
×
558
    }
×
559
#endif /* HAVE_SSL_GET0_ALPN_SELECTED */
×
560
    if (alpn != nullptr && alpnLen > 0) {
×
561
      result.insert(result.end(), alpn, alpn + alpnLen);
×
562
    }
×
563
    return result;
×
564
  }
×
565

566
  LibsslTLSVersion getTLSVersion() const override
567
  {
×
568
    auto proto = SSL_version(d_conn.get());
×
569
    switch (proto) {
×
570
    case TLS1_VERSION:
×
571
      return LibsslTLSVersion::TLS10;
×
572
    case TLS1_1_VERSION:
×
573
      return LibsslTLSVersion::TLS11;
×
574
    case TLS1_2_VERSION:
×
575
      return LibsslTLSVersion::TLS12;
×
576
#ifdef TLS1_3_VERSION
×
577
    case TLS1_3_VERSION:
×
578
      return LibsslTLSVersion::TLS13;
×
579
#endif /* TLS1_3_VERSION */
×
580
    default:
×
581
      return LibsslTLSVersion::Unknown;
×
582
    }
×
583
  }
×
584

585
  bool hasSessionBeenResumed() const override
586
  {
×
587
    if (d_conn) {
×
588
      return SSL_session_reused(d_conn.get()) != 0;
×
589
    }
×
590
    return false;
×
591
  }
×
592

593
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
594
  {
×
595
    return std::move(d_tlsSessions);
×
596
  }
×
597

598
  void setSession(std::unique_ptr<TLSSession>& session) override
599
  {
×
600
    auto sess = dynamic_cast<OpenSSLSession*>(session.get());
×
601
    if (!sess) {
×
602
      throw std::runtime_error("Unable to convert OpenSSL session");
×
603
    }
×
604

605
    auto native = sess->getNative();
×
606
    auto ret = SSL_set_session(d_conn.get(), native.get());
×
607
    if (ret != 1) {
×
608
      throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
×
609
    }
×
610
    session.reset();
×
611
  }
×
612

613
  void addNewTicket(SSL_SESSION* session)
614
  {
×
615
    d_tlsSessions.push_back(std::make_unique<OpenSSLSession>(std::unique_ptr<SSL_SESSION, void (*)(SSL_SESSION*)>(session, SSL_SESSION_free)));
×
616
  }
×
617

618
  void enableKTLS()
619
  {
×
620
    d_ktls = true;
×
621
  }
×
622

623
  [[nodiscard]] bool isClient() const
624
  {
×
625
    return d_isClient;
×
626
  }
×
627

628
  static void generateConnectionIndexIfNeeded()
629
  {
×
630
    auto init = s_initTLSConnIndex.lock();
×
631
    if (*init == true) {
×
632
      return;
×
633
    }
×
634

635
    /* not initialized yet */
636
    s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
×
637
    if (s_tlsConnIndex == -1) {
×
638
      throw std::runtime_error("Error getting an index for TLS connection data");
×
639
    }
×
640

641
    *init = true;
×
642
  }
×
643

644
  static int getConnectionIndex()
645
  {
×
646
    return s_tlsConnIndex;
×
647
  }
×
648

649
private:
650
  static LockGuarded<bool> s_initTLSConnIndex;
651
  static int s_tlsConnIndex;
652
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
653
  const std::shared_ptr<const OpenSSLTLSIOCtx> d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime
654
  std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
655
  const std::string d_hostname;
656
  const timeval d_timeout;
657
  bool d_connected{false};
658
  bool d_ktls{false};
659
  const bool d_isClient{false};
660
};
661

662
LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
663
int OpenSSLTLSConnection::s_tlsConnIndex{-1};
664

665
class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this<OpenSSLTLSIOCtx>
666
{
667
  /* Pass-key type, private to OpenSSLTLSIOCtx. Its explicit default constructor prevents
     `{}` from being passed, so the public constructors can only be reached through the
     create*SideContext() factories, which allocate with std::make_shared. */
  struct Private { explicit Private() = default; };
671

672
public:
673
  /* Factory for a server-side context. The object has to be owned by a std::shared_ptr,
     since connections are given shared_from_this(). */
  static std::shared_ptr<OpenSSLTLSIOCtx> createServerSideContext(TLSFrontend& frontend)
  {
    auto serverCtx = std::make_shared<OpenSSLTLSIOCtx>(frontend, Private());
    return serverCtx;
  }
×
677

678
  static std::shared_ptr<OpenSSLTLSIOCtx> createClientSideContext(const TLSContextParameters& params)
679
  {
×
680
    return std::make_shared<OpenSSLTLSIOCtx>(params, Private());
×
681
  }
×
682

683
  /* server side context */
684
  /* server side context
     Configures every SSL_CTX found in the frontend's SNI map:
     - tickets key callback,
     - OCSP stapling,
     - read-ahead,
     - error counters,
     - ALPN selection,
     - key log file.
     It then either generates the first tickets key or loads the keys from the configured file.
     If anything throws, the OpenSSL user registered by OpenSSLFrontendContext is released
     here, because the destructor does not run for a partially constructed object. */
  OpenSSLTLSIOCtx(TLSFrontend& frontend, [[maybe_unused]] Private priv): d_alpnProtos(getALPNVector(frontend.d_alpn, false)), d_feContext(std::make_unique<OpenSSLFrontendContext>(frontend.d_addr, frontend.d_tlsConfig))
  {
    try {
      OpenSSLTLSConnection::generateConnectionIndexIfNeeded();

      const auto& tlsConfig = frontend.d_tlsConfig;
      d_ticketsKeyRotationDelay = tlsConfig.d_ticketsKeyRotationDelay;
      const bool useOwnTicketsKeys = tlsConfig.d_enableTickets && tlsConfig.d_numberOfTicketsKeys > 0;

      for (auto& entry : d_feContext->d_sniMap) {
        auto* ctx = entry.second.get();

        /* sni_server_name_callback() retrieves the frontend context, and therefore the SNI map,
           through this callback data. It must be set even when session tickets are disabled,
           otherwise the certificate is never selected based on the requested server name. */
        libssl_set_ticket_key_callback_data(ctx, d_feContext.get());

        if (useOwnTicketsKeys) {
          /* use our own ticket keys handler so we can rotate them */
#if OPENSSL_VERSION_MAJOR >= 3
          SSL_CTX_set_tlsext_ticket_key_evp_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
#else
          SSL_CTX_set_tlsext_ticket_key_cb(ctx, &OpenSSLTLSIOCtx::ticketKeyCb);
#endif
        }

#ifndef DISABLE_OCSP_STAPLING
        if (!d_feContext->d_ocspResponses.empty()) {
          SSL_CTX_set_tlsext_status_cb(ctx, &OpenSSLTLSIOCtx::ocspStaplingCb);
          SSL_CTX_set_tlsext_status_arg(ctx, &d_feContext->d_ocspResponses);
        }
#endif /* DISABLE_OCSP_STAPLING */

        if (tlsConfig.d_readAhead) {
          SSL_CTX_set_read_ahead(ctx, 1);
        }

        libssl_set_error_counters_callback(*ctx, &frontend.d_tlsCounters);

        libssl_set_alpn_select_callback(ctx, alpnServerSelectCallback, this);

        if (!tlsConfig.d_keyLogFile.empty()) {
          /* NOTE(review): d_keyLogFile is overwritten for each context of the map. If the value
             returned by libssl_set_key_log_file() owns a FILE* that the SSL_CTX keeps using,
             the previously configured contexts could end up with a closed file. Confirm. */
          d_feContext->d_keyLogFile = libssl_set_key_log_file(ctx, tlsConfig.d_keyLogFile);
        }
      }

      if (tlsConfig.d_ticketKeyFile.empty()) {
        handleTicketsKeyRotation(time(nullptr));
      }
      else {
        OpenSSLTLSIOCtx::loadTicketsKeys(tlsConfig.d_ticketKeyFile);
      }
    }
    catch (...) {
      /* mirror the destructor, which will not run */
      d_tlsCtx.reset();
      unregisterOpenSSLUser();
      throw;
    }
  }
×
734

735
  /* client side context */
736
  OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private priv)
737
  {
×
738
    int sslOptions =
×
739
      SSL_OP_NO_SSLv2 |
×
740
      SSL_OP_NO_SSLv3 |
×
741
      SSL_OP_NO_COMPRESSION |
×
742
      SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
×
743
      SSL_OP_SINGLE_DH_USE |
×
744
      SSL_OP_SINGLE_ECDH_USE |
×
745
#ifdef SSL_OP_IGNORE_UNEXPECTED_EOF
×
746
      SSL_OP_IGNORE_UNEXPECTED_EOF |
×
747
#endif
×
748
      SSL_OP_CIPHER_SERVER_PREFERENCE;
×
749
    if (!params.d_enableRenegotiation) {
×
750
#ifdef SSL_OP_NO_RENEGOTIATION
×
751
      sslOptions |= SSL_OP_NO_RENEGOTIATION;
×
752
#elif defined(SSL_OP_NO_CLIENT_RENEGOTIATION)
753
      sslOptions |= SSL_OP_NO_CLIENT_RENEGOTIATION;
754
#endif
755
    }
×
756

757
    if (params.d_ktls) {
×
758
#ifdef SSL_OP_ENABLE_KTLS
×
759
      sslOptions |= SSL_OP_ENABLE_KTLS;
×
760
      d_ktls = true;
×
761
#endif /* SSL_OP_ENABLE_KTLS */
×
762
    }
×
763

764
    registerOpenSSLUser();
×
765

766
    OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
×
767

768
#ifdef HAVE_TLS_CLIENT_METHOD
769
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
770
#else
771
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
×
772
#endif
×
773
    if (!d_tlsCtx) {
×
774
      ERR_print_errors_fp(stderr);
×
775
      throw std::runtime_error("Error creating TLS context");
×
776
    }
×
777

778
    SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
×
779
#if defined(SSL_CTX_set_ecdh_auto)
×
780
    SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
×
781
#endif
×
782

783
    if (!params.d_ciphers.empty()) {
×
784
      if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
×
785
        ERR_print_errors_fp(stderr);
×
786
        throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
×
787
      }
×
788
    }
×
789
#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
×
790
    if (!params.d_ciphers13.empty()) {
×
791
      if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
×
792
        ERR_print_errors_fp(stderr);
×
793
        throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
×
794
      }
×
795
    }
×
796
#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
×
797

798
    if (params.d_validateCertificates) {
×
799
      if (params.d_caStore.empty())  {
×
800
        if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
×
801
          throw std::runtime_error("Error adding the system's default trusted CAs");
×
802
        }
×
803
      } else {
×
804
        if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
×
805
          throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
×
806
        }
×
807
      }
×
808

809
      SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
×
810
#if (OPENSSL_VERSION_NUMBER < 0x10002000L)
811
      warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
812
#endif
813
    }
×
814

815
    /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
816
       but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
817
    SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE);
×
818
    SSL_CTX_sess_set_new_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb);
×
819

820
    if (!params.d_keyLogFile.empty()) {
×
821
      d_keyLogFile = libssl_set_key_log_file(d_tlsCtx.get(), params.d_keyLogFile);
×
822
    }
×
823

824
    libssl_set_alpn_protos(d_tlsCtx.get(), getALPNVector(params.d_alpn, true));
×
825

826
#ifdef SSL_MODE_RELEASE_BUFFERS
×
827
    if (params.d_releaseBuffers) {
×
828
      SSL_CTX_set_mode(d_tlsCtx.get(), SSL_MODE_RELEASE_BUFFERS);
×
829
    }
×
830
#endif
×
831
  }
×
832

833
  OpenSSLTLSIOCtx(const OpenSSLTLSIOCtx&) = delete;
834
  OpenSSLTLSIOCtx(OpenSSLTLSIOCtx&&) = delete;
835
  OpenSSLTLSIOCtx& operator=(const OpenSSLTLSIOCtx&) = delete;
836
  OpenSSLTLSIOCtx& operator=(OpenSSLTLSIOCtx&&) = delete;
837

838
  /* Releases our SSL_CTX first, then drops this object's reference to the OpenSSL library. */
  ~OpenSSLTLSIOCtx() override
  {
    d_tlsCtx = nullptr;
    unregisterOpenSSLUser();
  }
×
843

844
#if OPENSSL_VERSION_MAJOR >= 3
845
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx, int enc)
846
#else
847
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx, int enc)
848
#endif
849
  {
×
850
    auto* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
×
851
    if (ctx == nullptr) {
×
852
      return -1;
×
853
    }
×
854

855
    int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
×
856
    if (enc == 0) {
×
857
      if (ret == 0 || ret == 2) {
×
858
        auto* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::getConnectionIndex()));
×
859
        if (conn != nullptr) {
×
860
          if (ret == 0) {
×
861
            conn->setUnknownTicketKey();
×
862
          }
×
863
          else if (ret == 2) {
×
864
            conn->setResumedFromInactiveTicketKey();
×
865
          }
×
866
        }
×
867
      }
×
868
    }
×
869

870
    return ret;
×
871
  }
×
872

873
#ifndef DISABLE_OCSP_STAPLING
874
  static int ocspStaplingCb(SSL* ssl, void* arg)
875
  {
×
876
    if (ssl == nullptr || arg == nullptr) {
×
877
      return SSL_TLSEXT_ERR_NOACK;
×
878
    }
×
879
    const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg);
×
880
    return libssl_ocsp_stapling_callback(ssl, *ocspMap);
×
881
  }
×
882
#endif /* DISABLE_OCSP_STAPLING */
883

884
  static int newTicketFromServerCb(SSL* ssl, SSL_SESSION* session)
  {
    /* New-session callback.
       Passes `session` to addNewTicket() on the OpenSSLTLSConnection stored in the
       ex_data of `ssl` and returns 1. Returns 0 when there is no session or no
       connection to hand it to. */
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
    auto* connection = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(ssl, OpenSSLTLSConnection::getConnectionIndex()));
    const bool canStore = (session != nullptr && connection != nullptr);
    if (canStore) {
      connection->addNewTicket(session);
      return 1;
    }
    return 0;
  }
×
894

895
  SSL_CTX* getOpenSSLContext() const
  {
    /* Server contexts keep their SSL_CTX in the frontend context;
       client contexts keep it in d_tlsCtx. */
    return d_feContext ? d_feContext->d_tlsCtx.get() : d_tlsCtx.get();
  }
×
902

903
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
904
  {
×
905
    handleTicketsKeyRotation(now);
×
906

907
    return std::make_unique<OpenSSLTLSConnection>(socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
×
908
  }
×
909

910
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
911
  {
×
912
    auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
×
913
    if (d_ktls) {
×
914
      conn->enableKTLS();
×
915
    }
×
916
    return conn;
×
917
  }
×
918

919
  void rotateTicketsKey(time_t now) override
920
  {
×
921
    d_feContext->d_ticketKeys.rotateTicketsKey(now);
×
922

923
    if (d_ticketsKeyRotationDelay > 0) {
×
924
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
×
925
    }
×
926
  }
×
927

928
  void loadTicketsKeys(const std::string& keyFile) final
929
  {
×
930
    d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
×
931

932
    if (d_ticketsKeyRotationDelay > 0) {
×
933
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
×
934
    }
×
935
  }
×
936

937
  void loadTicketsKey(const std::string& key) final
938
  {
×
939
    d_feContext->d_ticketKeys.loadTicketsKey(key);
×
940

941
    if (d_ticketsKeyRotationDelay > 0) {
×
942
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
×
943
    }
×
944
  }
×
945

946
  size_t getTicketsKeysCount() override
947
  {
×
948
    return d_feContext->d_ticketKeys.getKeysCount();
×
949
  }
×
950

951
  std::string getName() const override
952
  {
×
953
    return "openssl";
×
954
  }
×
955

956
  bool isServerContext() const
957
  {
×
958
    return d_feContext != nullptr;
×
959
  }
×
960

961
private:
962
  /* called in a client context, if the client advertised more than one ALPN value and the server returned more than one as well, to select the one to use. */
963
  static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
964
  {
×
965
    if (!arg) {
×
966
      return SSL_TLSEXT_ERR_ALERT_WARNING;
×
967
    }
×
968
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
969
    OpenSSLTLSIOCtx* obj = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);
×
970

971
    const pdns::views::UnsignedCharView inView(in, inlen);
×
972
    // Server preference algorithm as per RFC 7301 section 3.2
973
    for (const auto& tentative : obj->d_alpnProtos) {
×
974
      size_t pos = 0;
×
975
      while (pos < inView.size()) {
×
976
        size_t protoLen = inView.at(pos);
×
977
        pos++;
×
978
        if (protoLen > (inlen - pos)) {
×
979
          /* something is very wrong */
980
          return SSL_TLSEXT_ERR_ALERT_WARNING;
×
981
        }
×
982

983
        if (tentative.size() == protoLen && memcmp(&inView.at(pos), tentative.data(), tentative.size()) == 0) {
×
984
          *out = &inView.at(pos);
×
985
          *outlen = protoLen;
×
986
          return SSL_TLSEXT_ERR_OK;
×
987
        }
×
988
        pos += protoLen;
×
989
      }
×
990
    }
×
991

992
    return SSL_TLSEXT_ERR_NOACK;
×
993
  }
×
994

995
  const std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
996
  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
997
  std::unique_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
998
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
999
  bool d_ktls{false};
1000
};
1001

1002
#endif /* HAVE_LIBSSL */
1003

1004
#ifdef HAVE_GNUTLS
1005
#include <gnutls/gnutls.h>
1006
#include <gnutls/x509.h>
1007

1008
/* Best-effort locking of sensitive memory.
   With libsodium, the `size` bytes at `data` are passed to sodium_mlock() so they are
   kept out of swap; its return value is ignored. Without libsodium this does nothing. */
static void safe_memory_lock([[maybe_unused]] void* data, [[maybe_unused]] size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_mlock(data, size);
#endif /* HAVE_LIBSODIUM */
}
1014

1015
/* Wipe the `size` bytes at `data` in a way the compiler cannot optimise away.
   The primitive is the best one available at build time:
   - sodium_munlock(), which also unlocks the region;
   - explicit_bzero();
   - explicit_memset();
   - gnutls_memset();
   - as a last resort, a memset() whose effect is re-checked through a volatile pointer. */
static void safe_memory_release(void* data, size_t size)
{
#if defined(HAVE_LIBSODIUM)
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size == 0) {
    return;
  }

  volatile unsigned int zeroIndex = 0;
  const volatile unsigned char* bytes = static_cast<volatile unsigned char*>(data);
  do {
    memset(data, 0, size);
  } while (bytes[zeroIndex] != 0);
#endif
}
1038

1039
class GnuTLSTicketsKey
1040
{
1041
public:
1042
  GnuTLSTicketsKey()
1043
  {
1044
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
1045
      throw std::runtime_error("Error generating tickets key for TLS context");
1046
    }
1047

1048
    safe_memory_lock(d_key.data, d_key.size);
1049
  }
1050

1051
  GnuTLSTicketsKey(const std::string& key)
1052
  {
1053
    /* to be sure we are loading the correct amount of data, which
1054
       may change between versions, let's generate a correct key first */
1055
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
1056
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
1057
    }
1058

1059
    safe_memory_lock(d_key.data, d_key.size);
1060
    if (key.size() != d_key.size) {
1061
      safe_memory_release(d_key.data, d_key.size);
1062
      gnutls_free(d_key.data);
1063
      d_key.data = nullptr;
1064
      throw std::runtime_error("Invalid GnuTLS ticket key size");
1065
    }
1066
    memcpy(d_key.data, key.data(), key.size());
1067
  }
1068
  GnuTLSTicketsKey(std::ifstream& file)
1069
  {
1070
    /* to be sure we are loading the correct amount of data, which
1071
       may change between versions, let's generate a correct key first */
1072
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
1073
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
1074
    }
1075

1076
    safe_memory_lock(d_key.data, d_key.size);
1077

1078
    try {
1079
      file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
1080

1081
      if (file.fail()) {
1082
        throw std::runtime_error("Invalid GnuTLS tickets key file");
1083
      }
1084

1085
    }
1086
    catch (const std::exception& e) {
1087
      safe_memory_release(d_key.data, d_key.size);
1088
      gnutls_free(d_key.data);
1089
      d_key.data = nullptr;
1090
      throw;
1091
    }
1092
  }
1093
  [[nodiscard]] std::string content() const
1094
  {
1095
    std::string result{};
1096
    if (d_key.data != nullptr && d_key.size > 0) {
1097
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
1098
      result.append(reinterpret_cast<const char*>(d_key.data), d_key.size);
1099
      safe_memory_lock(result.data(), result.size());
1100
    }
1101
    return result;
1102
  }
1103

1104
  ~GnuTLSTicketsKey()
1105
  {
1106
    if (d_key.data != nullptr && d_key.size > 0) {
1107
      safe_memory_release(d_key.data, d_key.size);
1108
    }
1109
    gnutls_free(d_key.data);
1110
    d_key.data = nullptr;
1111
  }
1112
  const gnutls_datum_t& getKey() const
1113
  {
1114
    return d_key;
1115
  }
1116

1117
private:
1118
  gnutls_datum_t d_key{nullptr, 0};
1119
};
1120

1121
class GnuTLSSession : public TLSSession
1122
{
1123
public:
1124
  GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
1125
  {
1126
    sess.data = nullptr;
1127
    sess.size = 0;
1128
  }
1129

1130
  ~GnuTLSSession() override
1131
  {
1132
    if (d_sess.data != nullptr && d_sess.size > 0) {
1133
      safe_memory_release(d_sess.data, d_sess.size);
1134
    }
1135
    gnutls_free(d_sess.data);
1136
    d_sess.data = nullptr;
1137
  }
1138

1139
  const gnutls_datum_t& getNative()
1140
  {
1141
    return d_sess;
1142
  }
1143

1144
private:
1145
  gnutls_datum_t d_sess{nullptr, 0};
1146
};
1147

1148
class GnuTLSConnection: public TLSConnection
1149
{
1150
public:
1151
  /* server side connection */
1152
  GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
1153
  {
1154
    unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
1155
#ifdef GNUTLS_NO_SIGNAL
1156
    sslOptions |= GNUTLS_NO_SIGNAL;
1157
#endif
1158

1159
    d_socket = socket;
1160

1161
    gnutls_session_t conn;
1162
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
1163
      throw std::runtime_error("Error creating TLS connection");
1164
    }
1165

1166
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
1167
    conn = nullptr;
1168

1169
    if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) {
1170
      throw std::runtime_error("Error setting certificate and key to TLS connection");
1171
    }
1172

1173
    if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
1174
      throw std::runtime_error("Error setting ciphers to TLS connection");
1175
    }
1176

1177
    if (enableTickets && d_ticketsKey) {
1178
      const gnutls_datum_t& key = d_ticketsKey->getKey();
1179
      if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
1180
        throw std::runtime_error("Error setting the tickets key to TLS connection");
1181
      }
1182
    }
1183

1184
    gnutls_transport_set_int(d_conn.get(), d_socket);
1185

1186
    /* timeouts are in milliseconds */
1187
    gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1188
    gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1189
  }
1190

1191
  /* client-side connection */
1192
  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
1193
  {
1194
    unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
1195
#ifdef GNUTLS_NO_SIGNAL
1196
    sslOptions |= GNUTLS_NO_SIGNAL;
1197
#endif
1198

1199
    d_socket = socket;
1200

1201
    gnutls_session_t conn;
1202
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
1203
      throw std::runtime_error("Error creating TLS connection");
1204
    }
1205

1206
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
1207
    conn = nullptr;
1208

1209
    int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get());
1210
    if (rc != GNUTLS_E_SUCCESS) {
1211
      throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
1212
    }
1213

1214
    rc = gnutls_priority_set(d_conn.get(), priorityCache);
1215
    if (rc != GNUTLS_E_SUCCESS) {
1216
      throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
1217
    }
1218

1219
    gnutls_transport_set_int(d_conn.get(), d_socket);
1220

1221
    /* timeouts are in milliseconds */
1222
    gnutls_handshake_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1223
    gnutls_record_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
1224

1225
#ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
1226
    if (validateCerts && !d_host.empty()) {
1227
      gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
1228
      rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
1229
      if (rc != GNUTLS_E_SUCCESS) {
1230
        throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
1231
      }
1232
    }
1233
#else
1234
    /* no hostname validation for you */
1235
#endif
1236

1237
    /* allow access to our data in the callbacks */
1238
    gnutls_session_set_ptr(d_conn.get(), this);
1239
    gnutls_handshake_set_hook_function(d_conn.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET, GNUTLS_HOOK_POST, newTicketFromServerCb);
1240
  }
1241

1242
  /* The callback prototype changed in 3.4.0. */
1243
#if GNUTLS_VERSION_NUMBER >= 0x030400
1244
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */, const gnutls_datum_t* /* msg */)
1245
#else
1246
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */)
1247
#endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */
1248
  {
1249
    if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) {
1250
      return 0;
1251
    }
1252

1253
    GnuTLSConnection* conn = reinterpret_cast<GnuTLSConnection*>(gnutls_session_get_ptr(session));
1254
    if (conn == nullptr) {
1255
      return 0;
1256
    }
1257

1258
    gnutls_datum_t sess{nullptr, 0};
1259
    auto ret = gnutls_session_get_data2(session, &sess);
1260
    /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */
1261
    if (ret != GNUTLS_E_SUCCESS || sess.size <= 4) {
1262
      throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
1263
    }
1264
    conn->d_tlsSessions.push_back(std::make_unique<GnuTLSSession>(sess));
1265
    return 0;
1266
  }
1267

1268
  IOState tryConnect(bool fastOpen, [[maybe_unused]] const ComboAddress& remote) override
1269
  {
1270
    int ret = 0;
1271

1272
    if (fastOpen) {
1273
#ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
1274
      gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
1275
#endif
1276
    }
1277

1278
    do {
1279
      ret = gnutls_handshake(d_conn.get());
1280
      if (ret == GNUTLS_E_SUCCESS) {
1281
        d_handshakeDone = true;
1282
        return IOState::Done;
1283
      }
1284
      else if (ret == GNUTLS_E_AGAIN) {
1285
        int direction = gnutls_record_get_direction(d_conn.get());
1286
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
1287
      }
1288
      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
1289
        throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1290
      }
1291
    } while (ret == GNUTLS_E_INTERRUPTED);
1292

1293
    throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1294
  }
1295

1296
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override
1297
  {
1298
    struct timeval start = {0, 0};
1299
    struct timeval remainingTime = timeout;
1300
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
1301
      gettimeofday(&start, nullptr);
1302
    }
1303

1304
    IOState state;
1305
    do {
1306
      state = tryConnect(fastOpen, remote);
1307
      if (state == IOState::Done) {
1308
        return;
1309
      }
1310
      else if (state == IOState::NeedRead) {
1311
        int result = waitForData(d_socket, remainingTime);
1312
        if (result <= 0) {
1313
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1314
        }
1315
      }
1316
      else if (state == IOState::NeedWrite) {
1317
        int result = waitForRWData(d_socket, false, remainingTime);
1318
        if (result <= 0) {
1319
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
1320
        }
1321
      }
1322

1323
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
1324
        struct timeval now;
1325
        gettimeofday(&now, nullptr);
1326
        struct timeval elapsed = now - start;
1327
        if (now < start || remainingTime < elapsed) {
1328
          throw runtime_error("Timeout while establishing TLS connection");
1329
        }
1330
        start = now;
1331
        remainingTime = remainingTime - elapsed;
1332
      }
1333
    }
1334
    while (state != IOState::Done);
1335
  }
1336

1337
  void doHandshake() override
1338
  {
1339
    int ret = 0;
1340
    do {
1341
      ret = gnutls_handshake(d_conn.get());
1342
      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
1343
        if (d_client) {
1344
          throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1345
        }
1346
        else {
1347
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1348
        }
1349
      }
1350
    }
1351
    while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED);
1352

1353
    d_handshakeDone = true;
1354
  }
1355

1356
  IOState tryHandshake() override
1357
  {
1358
    int ret = 0;
1359

1360
    do {
1361
      ret = gnutls_handshake(d_conn.get());
1362
      if (ret == GNUTLS_E_SUCCESS) {
1363
        d_handshakeDone = true;
1364
        return IOState::Done;
1365
      }
1366
      else if (ret == GNUTLS_E_AGAIN) {
1367
        int direction = gnutls_record_get_direction(d_conn.get());
1368
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
1369
      }
1370
      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
1371
        if (d_client) {
1372
          std::string error;
1373
#ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
1374
          if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
1375
            gnutls_datum_t out;
1376
            if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) {
1377
              error = " (" + std::string(reinterpret_cast<const char*>(out.data)) + ")";
1378
              gnutls_free(out.data);
1379
            }
1380
          }
1381
#endif /* HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS */
1382
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)) + error);
1383
        }
1384
        else {
1385
          throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
1386
        }
1387
      }
1388
    } while (ret == GNUTLS_E_INTERRUPTED);
1389

1390
    if (d_client) {
1391
      throw std::runtime_error("Error establishinging a new connection: " + std::string(gnutls_strerror(ret)));
1392
    }
1393
    else {
1394
      throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
1395
    }
1396
  }
1397

1398
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
1399
  {
1400
    if (!d_handshakeDone) {
1401
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1402
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
1403
      auto state = tryHandshake();
1404
      if (state != IOState::Done) {
1405
        return state;
1406
      }
1407
    }
1408

1409
    do {
1410
      ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
1411
      if (res == 0) {
1412
        throw std::runtime_error("Error writing to TLS connection");
1413
      }
1414
      else if (res > 0) {
1415
        pos += static_cast<size_t>(res);
1416
      }
1417
      else if (res < 0) {
1418
        if (gnutls_error_is_fatal(res)) {
1419
          throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1420
        }
1421
        else if (res == GNUTLS_E_AGAIN) {
1422
          return IOState::NeedWrite;
1423
        }
1424
        vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
1425
      }
1426
    }
1427
    while (pos < toWrite);
1428
    return IOState::Done;
1429
  }
1430

1431
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
1432
  {
1433
    if (!d_handshakeDone) {
1434
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
1435
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
1436
      auto state = tryHandshake();
1437
      if (state != IOState::Done) {
1438
        return state;
1439
      }
1440
    }
1441

1442
    do {
1443
      ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
1444
      if (res == 0) {
1445
        throw std::runtime_error("EOF while reading from TLS connection");
1446
      }
1447
      else if (res > 0) {
1448
        pos += static_cast<size_t>(res);
1449
        if (allowIncomplete) {
1450
          break;
1451
        }
1452
      }
1453
      else if (res < 0) {
1454
        if (gnutls_error_is_fatal(res)) {
1455
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
1456
        }
1457
        else if (res == GNUTLS_E_AGAIN) {
1458
          return IOState::NeedRead;
1459
        }
1460
        vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
1461
      }
1462
    }
1463
    while (pos < toRead);
1464
    return IOState::Done;
1465
  }
1466

1467
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
1468
  {
1469
    size_t got = 0;
1470
    struct timeval start{0,0};
1471
    struct timeval  remainingTime = totalTimeout;
1472
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
1473
      gettimeofday(&start, nullptr);
1474
    }
1475

1476
    do {
1477
      ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
1478
      if (res == 0) {
1479
        throw std::runtime_error("EOF while reading from TLS connection");
1480
      }
1481
      else if (res > 0) {
1482
        got += static_cast<size_t>(res);
1483
        if (allowIncomplete) {
1484
          break;
1485
        }
1486
      }
1487
      else if (res < 0) {
1488
        if (gnutls_error_is_fatal(res)) {
1489
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
1490
        }
1491
        else if (res == GNUTLS_E_AGAIN) {
1492
          int result = waitForData(d_socket, readTimeout);
1493
          if (result <= 0) {
1494
            throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
1495
          }
1496
        }
1497
        else {
1498
          vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
1499
        }
1500
      }
1501

1502
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
1503
        struct timeval now;
1504
        gettimeofday(&now, nullptr);
1505
        struct timeval elapsed = now - start;
1506
        if (now < start || remainingTime < elapsed) {
1507
          throw runtime_error("Timeout while reading data");
1508
        }
1509
        start = now;
1510
        remainingTime = remainingTime - elapsed;
1511
      }
1512
    }
1513
    while (got < bufferSize);
1514

1515
    return got;
1516
  }
1517

1518
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
1519
  {
1520
    size_t got = 0;
1521

1522
    do {
1523
      ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
1524
      if (res == 0) {
1525
        throw std::runtime_error("Error writing to TLS connection");
1526
      }
1527
      else if (res > 0) {
1528
        got += static_cast<size_t>(res);
1529
      }
1530
      else if (res < 0) {
1531
        if (gnutls_error_is_fatal(res)) {
1532
          throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
1533
        }
1534
        else if (res == GNUTLS_E_AGAIN) {
1535
          int result = waitForRWData(d_socket, false, writeTimeout);
1536
          if (result <= 0) {
1537
            throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
1538
          }
1539
        }
1540
        else {
1541
          vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
1542
        }
1543
      }
1544
    }
1545
    while (got < bufferSize);
1546

1547
    return got;
1548
  }
1549

1550
  bool isUsable() const override
1551
  {
1552
    if (!d_conn) {
1553
      return false;
1554
    }
1555

1556
    /* as far as I can tell we can't peek so we cannot do better */
1557
    return isTCPSocketUsable(d_socket);
1558
  }
1559

1560
  std::string getServerNameIndication() const override
1561
  {
1562
    if (d_conn) {
1563
      unsigned int type;
1564
      size_t name_len = 256;
1565
      std::string sni;
1566
      sni.resize(name_len);
1567

1568
      int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
1569
      if (res == GNUTLS_E_SUCCESS) {
1570
        sni.resize(name_len);
1571
        return sni;
1572
      }
1573
    }
1574
    return std::string();
1575
  }
1576

1577
  std::vector<uint8_t> getNextProtocol() const override
1578
  {
1579
    std::vector<uint8_t> result;
1580
    if (!d_conn) {
1581
      return result;
1582
    }
1583
    gnutls_datum_t next;
1584
    if (gnutls_alpn_get_selected_protocol(d_conn.get(), &next) != GNUTLS_E_SUCCESS) {
1585
      return result;
1586
    }
1587
    result.insert(result.end(), next.data, next.data + next.size);
1588
    return result;
1589
  }
1590

1591
  LibsslTLSVersion getTLSVersion() const override
1592
  {
1593
    auto proto = gnutls_protocol_get_version(d_conn.get());
1594
    switch (proto) {
1595
    case GNUTLS_TLS1_0:
1596
      return LibsslTLSVersion::TLS10;
1597
    case GNUTLS_TLS1_1:
1598
      return LibsslTLSVersion::TLS11;
1599
    case GNUTLS_TLS1_2:
1600
      return LibsslTLSVersion::TLS12;
1601
#if GNUTLS_VERSION_NUMBER >= 0x030603
1602
    case GNUTLS_TLS1_3:
1603
      return LibsslTLSVersion::TLS13;
1604
#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */
1605
    default:
1606
      return LibsslTLSVersion::Unknown;
1607
    }
1608
  }
1609

1610
  bool hasSessionBeenResumed() const override
1611
  {
1612
    if (d_conn) {
1613
      return gnutls_session_is_resumed(d_conn.get()) != 0;
1614
    }
1615
    return false;
1616
  }
1617

1618
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
1619
  {
1620
    return std::move(d_tlsSessions);
1621
  }
1622

1623
  void setSession(std::unique_ptr<TLSSession>& session) override
1624
  {
1625
    auto sess = dynamic_cast<GnuTLSSession*>(session.get());
1626
    if (!sess) {
1627
      throw std::runtime_error("Unable to convert GnuTLS session");
1628
    }
1629

1630
    auto native = sess->getNative();
1631
    auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size);
1632
    if (ret != GNUTLS_E_SUCCESS) {
1633
      throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
1634
    }
1635
    session.reset();
1636
  }
1637

1638
  void close() override
1639
  {
1640
    if (d_conn) {
1641
      gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR);
1642
    }
1643
  }
1644

1645
  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos)
1646
  {
1647
    std::vector<gnutls_datum_t> values;
1648
    values.reserve(protos.size());
1649
    for (const auto& proto : protos) {
1650
      gnutls_datum_t value;
1651
      value.data = const_cast<uint8_t*>(proto.data());
1652
      value.size = proto.size();
1653
      values.push_back(value);
1654
    }
1655
    unsigned int flags = 0;
1656
#if GNUTLS_VERSION_NUMBER >= 0x030500
1657
    flags |= GNUTLS_ALPN_MANDATORY;
1658
#elif defined(GNUTLS_ALPN_MAND)
1659
    flags |= GNUTLS_ALPN_MAND;
1660
#endif
1661
    return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), flags);
1662
  }
1663

1664
  std::vector<int> getAsyncFDs() override
1665
  {
1666
    return {};
1667
  }
1668

1669
private:
1670
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1671
  std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
1672
  std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
1673
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
1674
  std::string d_host;
1675
  const bool d_client{false};
1676
  bool d_handshakeDone{false};
1677
};
1678

1679
class GnuTLSIOCtx: public TLSCtx
1680
{
1681
public:
1682
  /* server side context */
1683
  /* server side context
     Steps, in order:
     - load every certificate/key pair of `frontend`;
     - load the OCSP responses, unless OCSP stapling is disabled; failures are only logged;
     - set the DH parameters (GnuTLS >= 3.6) and the cipher preferences ("NORMAL" when none
       are configured);
     - run the tickets key rotation logic, or load the configured tickets key file.
     Throws std::runtime_error on any fatal configuration error. */
  GnuTLSIOCtx(TLSFrontend& frontend): d_protos(getALPNVector(frontend.d_alpn, false)), d_enableTickets(frontend.d_tlsConfig.d_enableTickets)
  {
    int rc = 0;
    d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;

    gnutls_certificate_credentials_t creds;
    rc = gnutls_certificate_allocate_credentials(&creds);
    if (rc != GNUTLS_E_SUCCESS) {
      throw std::runtime_error("Error allocating credentials for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
    }

    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
    creds = nullptr;

    for (const auto& pair : frontend.d_tlsConfig.d_certKeyPairs) {
      rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.d_cert.c_str(), pair.d_key->c_str(), GNUTLS_X509_FMT_PEM);
      if (rc != GNUTLS_E_SUCCESS) {
        throw std::runtime_error("Error loading certificate ('" + pair.d_cert + "') and key ('" + pair.d_key.value() + "') for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
      }
    }

#ifndef DISABLE_OCSP_STAPLING
    const auto& certKeyPairs = frontend.d_tlsConfig.d_certKeyPairs;
    size_t count = 0;
    for (const auto& file : frontend.d_tlsConfig.d_ocspFiles) {
      rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
      if (rc != GNUTLS_E_SUCCESS) {
        if (count < certKeyPairs.size()) {
          const auto& pair = certKeyPairs.at(count);
          warnlog("Error loading OCSP response from file '%s' for certificate ('%s') and key ('%s') for TLS context on %s: %s", file, pair.d_cert, pair.d_key.value(), frontend.d_addr.toStringWithPort(), gnutls_strerror(rc));
        }
        else {
          /* more OCSP files than certificate/key pairs: there is no pair to name here, and
             indexing certKeyPairs would throw std::out_of_range instead of just warning */
          warnlog("Error loading OCSP response from file '%s' for TLS context on %s, no matching certificate and key: %s", file, frontend.d_addr.toStringWithPort(), gnutls_strerror(rc));
        }
      }
      ++count;
    }
#endif /* DISABLE_OCSP_STAPLING */

#if GNUTLS_VERSION_NUMBER >= 0x030600
    rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
    if (rc != GNUTLS_E_SUCCESS) {
      throw std::runtime_error("Error setting DH params for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
    }
#endif

    rc = gnutls_priority_init(&d_priorityCache, frontend.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : frontend.d_tlsConfig.d_ciphers.c_str(), nullptr);
    if (rc != GNUTLS_E_SUCCESS) {
      throw std::runtime_error("Error setting up TLS cipher preferences to '" + frontend.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + frontend.d_addr.toStringWithPort());
    }

    try {
      if (frontend.d_tlsConfig.d_ticketKeyFile.empty()) {
        handleTicketsKeyRotation(time(nullptr));
      }
      else {
        GnuTLSIOCtx::loadTicketsKeys(frontend.d_tlsConfig.d_ticketKeyFile);
      }
    }
    catch(const std::runtime_error& e) {
      throw std::runtime_error("Error generating tickets key for TLS context on " + frontend.d_addr.toStringWithPort() + ": " + e.what());
    }
  }
1739

1740
  /* client side context */
1741
  GnuTLSIOCtx(const TLSContextParameters& params): d_protos(getALPNVector(params.d_alpn, true)), d_contextParameters(std::make_unique<TLSContextParameters>(params)), d_validateCerts(params.d_validateCertificates)
1742
  {
1743
    int rc = 0;
1744

1745
    gnutls_certificate_credentials_t creds;
1746
    rc = gnutls_certificate_allocate_credentials(&creds);
1747
    if (rc != GNUTLS_E_SUCCESS) {
1748
      throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1749
    }
1750

1751
    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
1752
    creds = nullptr;
1753

1754
    if (params.d_validateCertificates) {
1755
      if (params.d_caStore.empty()) {
1756
#if GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703
1757
        /* see https://gitlab.com/gnutls/gnutls/-/issues/1277 */
1758
        std::cerr<<"Warning: GnuTLS 3.7.0 - 3.7.2 have a memory leak when validating server certificates in some configurations (PKCS11 support enabled, and a default PKCS11 trust store), please consider upgrading GnuTLS, using the OpenSSL provider for outgoing connections, or explicitly setting a CA store"<<std::endl;
1759
#endif /* GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 */
1760
        rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
1761
        if (rc < 0) {
1762
          throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1763
        }
1764
      }
1765
      else {
1766
        rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
1767
        if (rc < 0) {
1768
          throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1769
        }
1770
      }
1771
    }
1772

1773
    rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
1774
    if (rc != GNUTLS_E_SUCCESS) {
1775
      throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
1776
    }
1777
  }
1778

1779
  ~GnuTLSIOCtx() override
1780
  {
1781
    d_creds.reset();
1782

1783
    if (d_priorityCache) {
1784
      gnutls_priority_deinit(d_priorityCache);
1785
    }
1786
  }
1787

1788
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
1789
  {
1790
    handleTicketsKeyRotation(now);
1791

1792
    std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
1793
    {
1794
      ticketsKey = *(d_ticketsKey.read_lock());
1795
    }
1796

1797
    auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets);
1798
    if (!d_protos.empty()) {
1799
      connection->setALPNProtos(d_protos);
1800
    }
1801
    return connection;
1802
  }
1803

1804
  static std::shared_ptr<gnutls_certificate_credentials_st> getPerThreadCredentials(bool validate, const std::string& caStore)
1805
  {
1806
    static thread_local std::map<std::pair<bool, std::string>, std::shared_ptr<gnutls_certificate_credentials_st>> t_credentials;
1807
    auto& entry = t_credentials[{validate, caStore}];
1808
    if (!entry) {
1809
      gnutls_certificate_credentials_t creds;
1810
      int rc = gnutls_certificate_allocate_credentials(&creds);
1811
      if (rc != GNUTLS_E_SUCCESS) {
1812
        throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
1813
      }
1814

1815
      entry = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
1816
      creds = nullptr;
1817

1818
      if (validate) {
1819
        if (caStore.empty()) {
1820
          rc = gnutls_certificate_set_x509_system_trust(entry.get());
1821
          if (rc < 0) {
1822
            throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
1823
          }
1824
        }
1825
        else {
1826
          rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM);
1827
          if (rc < 0) {
1828
            throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
1829
          }
1830
        }
1831
      }
1832
    }
1833
    return entry;
1834
  }
1835

1836
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool, int socket, const struct timeval& timeout) override
1837
  {
1838
    auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore);
1839
    auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, creds, d_priorityCache, d_validateCerts);
1840
    if (!d_protos.empty()) {
1841
      connection->setALPNProtos(d_protos);
1842
    }
1843
    return connection;
1844
  }
1845

1846
  void addTicketsKey(time_t now, std::shared_ptr<GnuTLSTicketsKey>&& newKey)
1847
  {
1848
    if (!d_enableTickets) {
1849
      return;
1850
    }
1851

1852
    {
1853
      *(d_ticketsKey.write_lock()) = std::move(newKey);
1854
    }
1855

1856
    if (d_ticketsKeyRotationDelay > 0) {
1857
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
1858
    }
1859

1860
    if (TLSCtx::hasTicketsKeyAddedHook()) {
1861
      auto ticketsKey = *(d_ticketsKey.read_lock());
1862
      auto content = ticketsKey->content();
1863
      TLSCtx::getTicketsKeyAddedHook()(content);
1864
      safe_memory_release(content.data(), content.size());
1865
    }
1866
  }
1867
  void rotateTicketsKey(time_t now) override
1868
  {
1869
    if (!d_enableTickets) {
1870
      return;
1871
    }
1872

1873
    auto newKey = std::make_shared<GnuTLSTicketsKey>();
1874
    addTicketsKey(now, std::move(newKey));
1875
  }
1876
  void loadTicketsKey(const std::string& key) final
1877
  {
1878
    if (!d_enableTickets) {
1879
      return;
1880
    }
1881

1882
    auto newKey = std::make_shared<GnuTLSTicketsKey>(key);
1883
    addTicketsKey(time(nullptr), std::move(newKey));
1884
  }
1885

1886
  void loadTicketsKeys(const std::string& keyFile) final
1887
  {
1888
    if (!d_enableTickets) {
1889
      return;
1890
    }
1891

1892
    std::ifstream file(keyFile);
1893
    auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
1894
    addTicketsKey(time(nullptr), std::move(newKey));
1895
    file.close();
1896
  }
1897

1898
  size_t getTicketsKeysCount() override
1899
  {
1900
    return *(d_ticketsKey.read_lock()) != nullptr ? 1 : 0;
1901
  }
1902

1903
  std::string getName() const override
1904
  {
1905
    return "gnutls";
1906
  }
1907

1908
private:
1909
  /* client context parameters */
1910
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1911
  const std::vector<std::vector<uint8_t>> d_protos;
1912
  std::unique_ptr<TLSContextParameters> d_contextParameters{nullptr};
1913
  gnutls_priority_t d_priorityCache{nullptr};
1914
  SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
1915
  bool d_enableTickets{true};
1916
  bool d_validateCerts{true};
1917
};
1918

1919
#endif /* HAVE_GNUTLS */
1920

1921
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1922

1923
bool TLSFrontend::setupTLS()
{
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
  std::shared_ptr<TLSCtx> ctx{nullptr};

  /* share the parent frontend's context whenever it already has one */
  if (d_parentFrontend) {
    ctx = d_parentFrontend->getContext();
  }
  if (ctx) {
    std::atomic_store_explicit(&d_ctx, std::move(ctx), std::memory_order_release);
    return true;
  }

  /* honour an explicitly configured provider, when it has been compiled in */
#if defined(HAVE_GNUTLS)
  if (d_provider == "gnutls") {
    ctx = std::make_shared<GnuTLSIOCtx>(*this);
  }
#endif /* HAVE_GNUTLS */
#if defined(HAVE_LIBSSL)
  if (d_provider == "openssl") {
    ctx = OpenSSLTLSIOCtx::createServerSideContext(*this);
  }
#endif /* HAVE_LIBSSL */

  /* otherwise fall back to the "best" available provider: OpenSSL first, then GnuTLS */
  if (ctx == nullptr) {
#if defined(HAVE_LIBSSL)
    ctx = OpenSSLTLSIOCtx::createServerSideContext(*this);
#elif defined(HAVE_GNUTLS)
    ctx = std::make_shared<GnuTLSIOCtx>(*this);
#else
#error "TLS support needed but neither libssl nor GnuTLS were selected"
#endif
  }

  std::atomic_store_explicit(&d_ctx, std::move(ctx), std::memory_order_release);
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
  return true;
}
×
1961

1962
std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
1963
{
×
1964
#ifdef HAVE_DNS_OVER_TLS
×
1965
  /* get the "best" available provider */
1966
  if (!params.d_provider.empty()) {
×
1967
#if defined(HAVE_GNUTLS)
1968
    if (params.d_provider == "gnutls") {
1969
      return std::make_shared<GnuTLSIOCtx>(params);
1970
    }
1971
#endif /* HAVE_GNUTLS */
1972
#if defined(HAVE_LIBSSL)
×
1973
    if (params.d_provider == "openssl") {
×
1974
      return OpenSSLTLSIOCtx::createClientSideContext(params);
×
1975
    }
×
1976
#endif /* HAVE_LIBSSL */
×
1977
  }
×
1978

1979
#if defined(HAVE_LIBSSL)
×
1980
  return OpenSSLTLSIOCtx::createClientSideContext(params);
×
1981
#elif defined(HAVE_GNUTLS)
1982
  return std::make_shared<GnuTLSIOCtx>(params);
1983
#else
1984
#error "DNS over TLS support needed but neither libssl nor GnuTLS were selected"
1985
#endif
1986

1987
#endif /* HAVE_DNS_OVER_TLS */
×
1988
  return nullptr;
×
1989
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc