• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 10789862120

10 Sep 2024 09:38AM UTC coverage: 55.763% (-0.03%) from 55.792%
10789862120

push

github

web-flow
Merge pull request #14641 from rgacogne/ddist19-backport-14573

dnsdist-1.9.x: Backport 14573 - Stop reporting timeouts in `topSlow()`, add `topTimeouts()`

13812 of 45324 branches covered (30.47%)

Branch coverage included in aggregate %.

3 of 9 new or added lines in 1 file covered. (33.33%)

13 existing lines in 4 files now uncovered.

47994 of 65513 relevant lines covered (73.26%)

3755251.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

56.41
/pdns/tcpiohandler.cc
1

2
#include "config.h"
3
#include "dolog.hh"
4
#include "iputils.hh"
5
#include "lock.hh"
6
#include "tcpiohandler.hh"
7

8
const bool TCPIOHandler::s_disableConnectForUnitTests = false;
9

10
#ifdef HAVE_LIBSODIUM
11
#include <sodium.h>
12
#endif /* HAVE_LIBSODIUM */
13

14
TLSCtx::tickets_key_added_hook TLSCtx::s_ticketsKeyAddedHook{nullptr};
15

16
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
17
#ifdef HAVE_LIBSSL
18

19
#include <openssl/conf.h>
20
#include <openssl/err.h>
21
#include <openssl/rand.h>
22
#include <openssl/ssl.h>
23
#include <openssl/x509v3.h>
24

25
#include "libssl.hh"
26

27

28
class OpenSSLFrontendContext
29
{
30
public:
31
  OpenSSLFrontendContext(const ComboAddress& addr, const TLSConfig& tlsConfig): d_ticketKeys(tlsConfig.d_numberOfTicketsKeys)
32
  {
49✔
33
    registerOpenSSLUser();
49✔
34

35
    auto [ctx, warnings] = libssl_init_server_context(tlsConfig, d_ocspResponses);
49✔
36
    for (const auto& warning : warnings) {
49✔
37
      warnlog("%s", warning);
2✔
38
    }
2✔
39
    d_tlsCtx = std::move(ctx);
49✔
40

41
    if (!d_tlsCtx) {
49!
42
      ERR_print_errors_fp(stderr);
×
43
      throw std::runtime_error("Error creating TLS context on " + addr.toStringWithPort());
×
44
    }
×
45
  }
49✔
46

47
  void cleanup()
48
  {
×
49
    d_tlsCtx.reset();
×
50

×
51
    unregisterOpenSSLUser();
×
52
  }
×
53

54
  OpenSSLTLSTicketKeysRing d_ticketKeys;
55
  std::map<int, std::string> d_ocspResponses;
56
  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx{nullptr, SSL_CTX_free};
57
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
58
};
59

60
class OpenSSLSession : public TLSSession
61
{
62
public:
63
  OpenSSLSession(std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)>&& sess): d_sess(std::move(sess))
64
  {
80✔
65
  }
80✔
66

67
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> getNative()
68
  {
42✔
69
    return std::move(d_sess);
42✔
70
  }
42✔
71

72
private:
73
  std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
74
};
75

76
class OpenSSLTLSConnection: public TLSConnection
77
{
78
public:
79
  /* server side connection */
80
  OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
81
  {
255✔
82
    d_socket = socket;
255✔
83

84
    if (!d_conn) {
255!
85
      vinfolog("Error creating TLS object");
×
86
      if (g_verbose) {
×
87
        ERR_print_errors_fp(stderr);
×
88
      }
×
89
      throw std::runtime_error("Error creating TLS object");
×
90
    }
×
91

92
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
255!
93
      throw std::runtime_error("Error assigning socket");
×
94
    }
×
95

96
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
255✔
97
  }
255✔
98

99
  /* client-side connection */
100
  OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<SSL_CTX>& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
101
  {
74✔
102
    d_socket = socket;
74✔
103

104
    if (!d_conn) {
74!
105
      vinfolog("Error creating TLS object");
×
106
      if (g_verbose) {
×
107
        ERR_print_errors_fp(stderr);
×
108
      }
×
109
      throw std::runtime_error("Error creating TLS object");
×
110
    }
×
111

112
    if (!SSL_set_fd(d_conn.get(), d_socket)) {
74!
113
      throw std::runtime_error("Error assigning socket");
×
114
    }
×
115

116
    /* set outgoing Server Name Indication */
117
    if (!d_hostname.empty() && SSL_set_tlsext_host_name(d_conn.get(), d_hostname.c_str()) != 1) {
74!
118
      throw std::runtime_error("Error setting TLS SNI to " + d_hostname);
×
119
    }
×
120

121
    if (hostIsAddr) {
74✔
122
#if (OPENSSL_VERSION_NUMBER >= 0x10002000L)
2✔
123
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
2✔
124
      /* Enable automatic IP checks */
125
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
2✔
126
      if (X509_VERIFY_PARAM_set1_ip_asc(param, d_hostname.c_str()) != 1) {
2!
127
        throw std::runtime_error("Error setting TLS IP for certificate validation");
×
128
      }
×
129
#else
130
      /* no validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
131
#endif
132
    }
2✔
133
    else {
72✔
134
#if (OPENSSL_VERSION_NUMBER >= 0x1010000fL) && defined(HAVE_SSL_SET_HOSTFLAGS) // grrr libressl
72✔
135
      SSL_set_hostflags(d_conn.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
72✔
136
      if (SSL_set1_host(d_conn.get(), d_hostname.c_str()) != 1) {
72!
137
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
×
138
      }
×
139
#elif (OPENSSL_VERSION_NUMBER >= 0x10002000L)
140
      X509_VERIFY_PARAM *param = SSL_get0_param(d_conn.get());
141
      /* Enable automatic hostname checks */
142
      X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
143
      if (X509_VERIFY_PARAM_set1_host(param, d_hostname.c_str(), d_hostname.size()) != 1) {
144
        throw std::runtime_error("Error setting TLS hostname for certificate validation");
145
      }
146
#else
147
      /* no hostname validation for you, see https://wiki.openssl.org/index.php/Hostname_validation */
148
#endif
149
    }
72✔
150

151
    SSL_set_ex_data(d_conn.get(), getConnectionIndex(), this);
74✔
152
  }
74✔
153

154
  std::vector<int> getAsyncFDs() override
155
  {
105✔
156
    std::vector<int> results;
105✔
157
#ifdef SSL_MODE_ASYNC
105✔
158
    if (SSL_waiting_for_async(d_conn.get()) != 1) {
105!
159
      return results;
105✔
160
    }
105✔
161

162
    OSSL_ASYNC_FD fds[32];
×
163
    size_t numfds = sizeof(fds)/sizeof(*fds);
×
164
    SSL_get_all_async_fds(d_conn.get(), nullptr, &numfds);
×
165
    if (numfds == 0) {
×
166
      return results;
×
167
    }
×
168

169
    SSL_get_all_async_fds(d_conn.get(), fds, &numfds);
×
170
    results.reserve(numfds);
×
171
    for (size_t idx = 0; idx < numfds; idx++) {
×
172
      results.push_back(fds[idx]);
×
173
    }
×
174
#endif
×
175
    return results;
×
176
  }
×
177

178
  /* Maps the result of an OpenSSL I/O call to an IOState, using SSL_get_error().
     Returns NeedRead / NeedWrite (or Async, when available) if the operation has to be retried,
     and throws std::runtime_error for every other error, including a closed connection. */
  IOState convertIORequestToIOState(int res) const
  {
    const int error = SSL_get_error(d_conn.get(), res);

    switch (error) {
    case SSL_ERROR_WANT_READ:
      return IOState::NeedRead;
    case SSL_ERROR_WANT_WRITE:
      return IOState::NeedWrite;
    case SSL_ERROR_SYSCALL:
      /* no errno set: reported the same way as a connection closed by the peer */
      if (errno == 0) {
        throw std::runtime_error("TLS connection closed by remote end");
      }
      throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
    case SSL_ERROR_ZERO_RETURN:
      throw std::runtime_error("TLS connection closed by remote end");
#ifdef SSL_MODE_ASYNC
    case SSL_ERROR_WANT_ASYNC:
      return IOState::Async;
#endif
    default:
      break;
    }

    /* the OpenSSL error string is only included in verbose mode */
    if (g_verbose) {
      throw std::runtime_error("Error while processing TLS connection: (" + std::to_string(error) + ") " + libssl_get_error_string());
    }
    throw std::runtime_error("Error while processing TLS connection: " + std::to_string(error));
  }
1,073✔
211

212
  void handleIORequest(int res, const struct timeval& timeout)
213
  {
5✔
214
    auto state = convertIORequestToIOState(res);
5✔
215
    if (state == IOState::NeedRead) {
5✔
216
      res = waitForData(d_socket, timeout.tv_sec, timeout.tv_usec);
4✔
217
      if (res == 0) {
4!
218
        throw std::runtime_error("Timeout while reading from TLS connection");
×
219
      }
×
220
      else if (res < 0) {
4!
221
        throw std::runtime_error("Error waiting to read from TLS connection");
×
222
      }
×
223
    }
4✔
224
    else if (state == IOState::NeedWrite) {
1!
225
      res = waitForRWData(d_socket, false, timeout.tv_sec, timeout.tv_usec);
×
226
      if (res == 0) {
×
227
        throw std::runtime_error("Timeout while writing to TLS connection");
×
228
      }
×
229
      else if (res < 0) {
×
230
        throw std::runtime_error("Error waiting to write to TLS connection");
×
231
      }
×
232
    }
×
233
  }
5✔
234

235
  /* Non-blocking connect: calls SSL_connect() once.
     Returns Done when the handshake completed, the converted I/O state when it has to be
     retried, and throws std::runtime_error when SSL_connect() returns 0. */
  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
  {
    /* sorry */
    (void) fastOpen;
    (void) remote;

    const int res = SSL_connect(d_conn.get());
    if (res == 1) {
      return IOState::Done;
    }
    if (res < 0) {
      return convertIORequestToIOState(res);
    }

    throw std::runtime_error("Error establishing a TLS connection");
  }
67✔
251

252
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval &timeout) override
253
  {
4✔
254
    /* sorry */
255
    (void) fastOpen;
4✔
256
    (void) remote;
4✔
257

258
    struct timeval start{0,0};
4✔
259
    struct timeval remainingTime = timeout;
4✔
260
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
4!
261
      gettimeofday(&start, nullptr);
4✔
262
    }
4✔
263

264
    int res = 0;
4✔
265
    do {
8✔
266
      res = SSL_connect(d_conn.get());
8✔
267
      if (res < 0) {
8✔
268
        handleIORequest(res, remainingTime);
5✔
269
      }
5✔
270

271
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
8!
272
        struct timeval now;
7✔
273
        gettimeofday(&now, nullptr);
7✔
274
        struct timeval elapsed = now - start;
7✔
275
        if (now < start || remainingTime < elapsed) {
7!
276
          throw runtime_error("Timeout while establishing TLS connection");
×
277
        }
×
278
        start = now;
7✔
279
        remainingTime = remainingTime - elapsed;
7✔
280
      }
7✔
281
    }
8✔
282
    while (res != 1);
8✔
283
  }
4✔
284

285
  /* Non-blocking handshake. Does nothing on a client connection (no frontend context);
     on a server connection it calls SSL_accept() once and returns Done, or the I/O state
     to wait for. Throws std::runtime_error when SSL_accept() returns 0. */
  IOState tryHandshake() override
  {
    const bool isClient = !d_feContext;
    if (isClient) {
      /* In client mode, the handshake is initiated by the call to SSL_connect()
         done from connect()/tryConnect().
         In blocking mode it does not return before the handshake has been finished,
         and in non-blocking mode calling SSL_connect() once is enough for SSL_write()
         and SSL_read() to transparently continue to negotiate the connection after that
         (equivalent to doing SSL_set_connect_state() plus trying to write).
      */
      return IOState::Done;
    }

    /* As explained above in the client-mode block, we only need to call SSL_accept() once
       for SSL_write() and SSL_read() to transparently continue to negotiate the connection after that.
       It is equivalent to calling SSL_set_accept_state() plus trying to read.
    */
    const int res = SSL_accept(d_conn.get());
    if (res == 1) {
      return IOState::Done;
    }
    if (res < 0) {
      return convertIORequestToIOState(res);
    }

    throw std::runtime_error("Error accepting TLS connection");
  }
508✔
312

313
  void doHandshake() override
314
  {
×
315
    if (!d_feContext) {
×
316
      /* we are a client, nothing to do, see the non-blocking version */
317
      return;
×
318
    }
×
319

320
    int res = 0;
×
321
    do {
×
322
      res = SSL_accept(d_conn.get());
×
323
      if (res < 0) {
×
324
        handleIORequest(res, d_timeout);
×
325
      }
×
326
    }
×
327
    while (res < 0);
×
328

329
    if (res != 1) {
×
330
      throw std::runtime_error("Error accepting TLS connection");
×
331
    }
×
332
  }
×
333

334
  /* Non-blocking write of buffer[pos, toWrite). pos is advanced by the number of bytes
     written. Returns Done once pos reaches toWrite, or the I/O state to wait for when
     SSL_write() cannot progress (exceptions from convertIORequestToIOState() propagate). */
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
  {
    const bool clientNotYetConnected = !d_feContext && !d_connected;
    if (clientNotYetConnected && d_ktls) {
      /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
      SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
    }

    do {
      const int written = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
      if (written <= 0) {
        return convertIORequestToIOState(written);
      }
      pos += static_cast<size_t>(written);
    }
    while (pos < toWrite);

    /* a complete write went through: remember that the connection is established */
    d_connected = true;

    return IOState::Done;
  }
1,018✔
360

361
  /* Non-blocking read into buffer[pos, toRead). pos is advanced by the number of bytes read.
     Returns Done when toRead has been reached, or after the first successful read if
     allowIncomplete is set; otherwise returns the I/O state to wait for
     (exceptions from convertIORequestToIOState() propagate). */
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
  {
    do {
      const int received = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
      if (received <= 0) {
        return convertIORequestToIOState(received);
      }
      pos += static_cast<size_t>(received);
    }
    while (!allowIncomplete && pos < toRead);

    return IOState::Done;
  }
2,105✔
378

379
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
380
  {
×
381
    size_t got = 0;
×
382
    struct timeval start = {0, 0};
×
383
    struct timeval remainingTime = totalTimeout;
×
384
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
385
      gettimeofday(&start, nullptr);
×
386
    }
×
387

388
    do {
×
389
      int res = SSL_read(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
390
      if (res <= 0) {
×
391
        handleIORequest(res, readTimeout);
×
392
      }
×
393
      else {
×
394
        got += static_cast<size_t>(res);
×
395
        if (allowIncomplete) {
×
396
          break;
×
397
        }
×
398
      }
×
399

400
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
401
        struct timeval now;
×
402
        gettimeofday(&now, nullptr);
×
403
        struct timeval elapsed = now - start;
×
404
        if (now < start || remainingTime < elapsed) {
×
405
          throw runtime_error("Timeout while reading data");
×
406
        }
×
407
        start = now;
×
408
        remainingTime = remainingTime - elapsed;
×
409
      }
×
410
    }
×
411
    while (got < bufferSize);
×
412

413
    return got;
×
414
  }
×
415

416
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
417
  {
×
418
    size_t got = 0;
×
419
    do {
×
420
      int res = SSL_write(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), static_cast<int>(bufferSize - got));
×
421
      if (res <= 0) {
×
422
        handleIORequest(res, writeTimeout);
×
423
      }
×
424
      else {
×
425
        got += static_cast<size_t>(res);
×
426
      }
×
427
    }
×
428
    while (got < bufferSize);
×
429

430
    return got;
×
431
  }
×
432

433
  bool isUsable() const override
434
  {
2✔
435
    if (!d_conn) {
2!
436
      return false;
×
437
    }
×
438

439
    char buf;
2✔
440
    int res = SSL_peek(d_conn.get(), &buf, sizeof(buf));
2✔
441
    if (res > 0) {
2!
442
      return true;
×
443
    }
×
444
    try {
2✔
445
      convertIORequestToIOState(res);
2✔
446
      return true;
2✔
447
    }
2✔
448
    catch (...) {
2✔
449
      return false;
×
450
    }
×
451

452
    return false;
×
453
  }
2✔
454

455
  void close() override
456
  {
306✔
457
    if (d_conn) {
306!
458
      SSL_shutdown(d_conn.get());
306✔
459
    }
306✔
460
  }
306✔
461

462
  std::string getServerNameIndication() const override
463
  {
410✔
464
    if (d_conn) {
410!
465
      const char* value = SSL_get_servername(d_conn.get(), TLSEXT_NAMETYPE_host_name);
410✔
466
      if (value) {
410!
467
        return std::string(value);
410✔
468
      }
410✔
469
    }
410✔
470
    return std::string();
×
471
  }
410✔
472

473
  std::vector<uint8_t> getNextProtocol() const override
474
  {
148✔
475
    std::vector<uint8_t> result;
148✔
476
    if (!d_conn) {
148!
477
      return result;
×
478
    }
×
479

480
    const unsigned char* alpn = nullptr;
148✔
481
    unsigned int alpnLen  = 0;
148✔
482
#ifndef DISABLE_NPN
148✔
483
#ifdef HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED
148✔
484
    SSL_get0_next_proto_negotiated(d_conn.get(), &alpn, &alpnLen);
148✔
485
#endif /* HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED */
148✔
486
#endif /* DISABLE_NPN */
148✔
487
#ifdef HAVE_SSL_GET0_ALPN_SELECTED
148✔
488
    if (alpn == nullptr) {
148!
489
      SSL_get0_alpn_selected(d_conn.get(), &alpn, &alpnLen);
148✔
490
    }
148✔
491
#endif /* HAVE_SSL_GET0_ALPN_SELECTED */
148✔
492
    if (alpn != nullptr && alpnLen > 0) {
148!
493
      result.insert(result.end(), alpn, alpn + alpnLen);
126✔
494
    }
126✔
495
    return result;
148✔
496
  }
148✔
497

498
  /* Returns the TLS protocol version of this connection, as reported by SSL_version().
     Returns LibsslTLSVersion::Unknown when there is no SSL object or when the
     version is not one of those handled below. */
  LibsslTLSVersion getTLSVersion() const override
  {
    /* same guard as the other accessors: never hand a null SSL object to OpenSSL */
    if (!d_conn) {
      return LibsslTLSVersion::Unknown;
    }

    auto proto = SSL_version(d_conn.get());
    switch (proto) {
    case TLS1_VERSION:
      return LibsslTLSVersion::TLS10;
    case TLS1_1_VERSION:
      return LibsslTLSVersion::TLS11;
    case TLS1_2_VERSION:
      return LibsslTLSVersion::TLS12;
#ifdef TLS1_3_VERSION
    case TLS1_3_VERSION:
      return LibsslTLSVersion::TLS13;
#endif /* TLS1_3_VERSION */
    default:
      return LibsslTLSVersion::Unknown;
    }
  }
413✔
516

517
  bool hasSessionBeenResumed() const override
518
  {
285✔
519
    if (d_conn) {
285!
520
      return SSL_session_reused(d_conn.get()) != 0;
285✔
521
    }
285✔
522
    return false;
×
523
  }
285✔
524

525
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
526
  {
47✔
527
    return std::move(d_tlsSessions);
47✔
528
  }
47✔
529

530
  void setSession(std::unique_ptr<TLSSession>& session) override
531
  {
42✔
532
    auto sess = dynamic_cast<OpenSSLSession*>(session.get());
42✔
533
    if (!sess) {
42!
534
      throw std::runtime_error("Unable to convert OpenSSL session");
×
535
    }
×
536

537
    auto native = sess->getNative();
42✔
538
    auto ret = SSL_set_session(d_conn.get(), native.get());
42✔
539
    if (ret != 1) {
42!
540
      throw std::runtime_error("Error setting up session: " + libssl_get_error_string());
×
541
    }
×
542
    session.reset();
42✔
543
  }
42✔
544

545
  /* Takes ownership of a native session (released with SSL_SESSION_free) and
     appends it to the list returned by getSessions(). */
  void addNewTicket(SSL_SESSION* session)
  {
    std::unique_ptr<SSL_SESSION, void (*)(SSL_SESSION*)> owned(session, SSL_SESSION_free);
    d_tlsSessions.push_back(std::make_unique<OpenSSLSession>(std::move(owned)));
  }
80✔
549

550
  void enableKTLS()
551
  {
×
552
    d_ktls = true;
×
553
  }
×
554

555
  static void generateConnectionIndexIfNeeded()
556
  {
74✔
557
    auto init = s_initTLSConnIndex.lock();
74✔
558
    if (*init == true) {
74✔
559
      return;
21✔
560
    }
21✔
561

562
    /* not initialized yet */
563
    s_tlsConnIndex = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
53✔
564
    if (s_tlsConnIndex == -1) {
53!
565
      throw std::runtime_error("Error getting an index for TLS connection data");
×
566
    }
×
567

568
    *init = true;
53✔
569
  }
53✔
570

571
  static int getConnectionIndex()
572
  {
423✔
573
    return s_tlsConnIndex;
423✔
574
  }
423✔
575

576
private:
577
  static LockGuarded<bool> s_initTLSConnIndex;
578
  static int s_tlsConnIndex;
579
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
580
  /* server context */
581
  std::shared_ptr<OpenSSLFrontendContext> d_feContext;
582
  /* client context */
583
  std::shared_ptr<SSL_CTX> d_tlsCtx;
584
  std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
585
  std::string d_hostname;
586
  struct timeval d_timeout;
587
  bool d_connected{false};
588
  bool d_ktls{false};
589
};
590

591
LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
592
int OpenSSLTLSConnection::s_tlsConnIndex{-1};
593

594
class OpenSSLTLSIOCtx: public TLSCtx
595
{
596
public:
597
  /* server side context */
598
  /* Builds the server-side context for a frontend: creates the frontend context, then
     configures ticket keys handling, OCSP stapling, read-ahead, error counters and the
     key log file according to fe.d_tlsConfig, and finally sets up the initial ticket
     keys (loaded from the configured file, or via handleTicketsKeyRotation()).
     Exceptions thrown by any of these steps propagate to the caller. */
  OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
  {
    OpenSSLTLSConnection::generateConnectionIndexIfNeeded();

    d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;

    if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
      /* use our own ticket keys handler so we can rotate them */
#if OPENSSL_VERSION_MAJOR >= 3
      SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
#else
      SSL_CTX_set_tlsext_ticket_key_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
#endif
      libssl_set_ticket_key_callback_data(d_feContext->d_tlsCtx.get(), d_feContext.get());
    }

#ifndef DISABLE_OCSP_STAPLING
    if (!d_feContext->d_ocspResponses.empty()) {
      SSL_CTX_set_tlsext_status_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ocspStaplingCb);
      SSL_CTX_set_tlsext_status_arg(d_feContext->d_tlsCtx.get(), &d_feContext->d_ocspResponses);
    }
#endif /* DISABLE_OCSP_STAPLING */

    if (fe.d_tlsConfig.d_readAhead) {
      SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1);
    }

    libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);

    if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
      d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
    }

    /* no try/catch here: the former handler only rethrew, so errors propagate exactly as before */
    if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
      handleTicketsKeyRotation(time(nullptr));
    }
    else {
      OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
    }
  }
49✔
643

644
  /* client side context */
645
  OpenSSLTLSIOCtx(const TLSContextParameters& params)
646
  {
25✔
647
    int sslOptions =
25✔
648
      SSL_OP_NO_SSLv2 |
25✔
649
      SSL_OP_NO_SSLv3 |
25✔
650
      SSL_OP_NO_COMPRESSION |
25✔
651
      SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION |
25✔
652
      SSL_OP_SINGLE_DH_USE |
25✔
653
      SSL_OP_SINGLE_ECDH_USE |
25✔
654
#ifdef SSL_OP_IGNORE_UNEXPECTED_EOF
25✔
655
      SSL_OP_IGNORE_UNEXPECTED_EOF |
25✔
656
#endif
25✔
657
      SSL_OP_CIPHER_SERVER_PREFERENCE;
25✔
658
    if (!params.d_enableRenegotiation) {
25!
659
#ifdef SSL_OP_NO_RENEGOTIATION
25✔
660
      sslOptions |= SSL_OP_NO_RENEGOTIATION;
25✔
661
#elif defined(SSL_OP_NO_CLIENT_RENEGOTIATION)
662
      sslOptions |= SSL_OP_NO_CLIENT_RENEGOTIATION;
663
#endif
664
    }
25✔
665

666
    if (params.d_ktls) {
25!
667
#ifdef SSL_OP_ENABLE_KTLS
×
668
      sslOptions |= SSL_OP_ENABLE_KTLS;
×
669
      d_ktls = true;
×
670
#endif /* SSL_OP_ENABLE_KTLS */
×
671
    }
×
672

673
    registerOpenSSLUser();
25✔
674

675
    OpenSSLTLSConnection::generateConnectionIndexIfNeeded();
25✔
676

677
#ifdef HAVE_TLS_CLIENT_METHOD
678
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(TLS_client_method()), SSL_CTX_free);
679
#else
680
    d_tlsCtx = std::shared_ptr<SSL_CTX>(SSL_CTX_new(SSLv23_client_method()), SSL_CTX_free);
25✔
681
#endif
25✔
682
    if (!d_tlsCtx) {
25!
683
      ERR_print_errors_fp(stderr);
×
684
      throw std::runtime_error("Error creating TLS context");
×
685
    }
×
686

687
    SSL_CTX_set_options(d_tlsCtx.get(), sslOptions);
25✔
688
#if defined(SSL_CTX_set_ecdh_auto)
25✔
689
    SSL_CTX_set_ecdh_auto(d_tlsCtx.get(), 1);
25✔
690
#endif
25✔
691

692
    if (!params.d_ciphers.empty()) {
25!
693
      if (SSL_CTX_set_cipher_list(d_tlsCtx.get(), params.d_ciphers.c_str()) != 1) {
×
694
        ERR_print_errors_fp(stderr);
×
695
        throw std::runtime_error("Error setting the cipher list to '" + params.d_ciphers + "' for the TLS context");
×
696
      }
×
697
    }
×
698
#ifdef HAVE_SSL_CTX_SET_CIPHERSUITES
25✔
699
    if (!params.d_ciphers13.empty()) {
25!
700
      if (SSL_CTX_set_ciphersuites(d_tlsCtx.get(), params.d_ciphers13.c_str()) != 1) {
×
701
        ERR_print_errors_fp(stderr);
×
702
        throw std::runtime_error("Error setting the TLS 1.3 cipher list to '" + params.d_ciphers13 + "' for the TLS context");
×
703
      }
×
704
    }
×
705
#endif /* HAVE_SSL_CTX_SET_CIPHERSUITES */
25✔
706

707
    if (params.d_validateCertificates) {
25✔
708
      if (params.d_caStore.empty())  {
21!
709
        if (SSL_CTX_set_default_verify_paths(d_tlsCtx.get()) != 1) {
×
710
          throw std::runtime_error("Error adding the system's default trusted CAs");
×
711
        }
×
712
      } else {
21✔
713
        if (SSL_CTX_load_verify_locations(d_tlsCtx.get(), params.d_caStore.c_str(), nullptr) != 1) {
21!
714
          throw std::runtime_error("Error adding the trusted CAs file " + params.d_caStore);
×
715
        }
×
716
      }
21✔
717

718
      SSL_CTX_set_verify(d_tlsCtx.get(), SSL_VERIFY_PEER, nullptr);
21✔
719
#if (OPENSSL_VERSION_NUMBER < 0x10002000L)
720
      warnlog("TLS hostname validation requested but not supported for OpenSSL < 1.0.2");
721
#endif
722
    }
21✔
723

724
    /* we need to set SSL_SESS_CACHE_CLIENT for the "new ticket" callback (below) to be called,
725
       but we don't want OpenSSL to cache the session itself so we set SSL_SESS_CACHE_NO_INTERNAL_STORE as well */
726
    SSL_CTX_set_session_cache_mode(d_tlsCtx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE);
25✔
727
    SSL_CTX_sess_set_new_cb(d_tlsCtx.get(), &OpenSSLTLSIOCtx::newTicketFromServerCb);
25✔
728

729
#ifdef SSL_MODE_RELEASE_BUFFERS
25✔
730
    if (params.d_releaseBuffers) {
25!
731
      SSL_CTX_set_mode(d_tlsCtx.get(), SSL_MODE_RELEASE_BUFFERS);
25✔
732
    }
25✔
733
#endif
25✔
734
  }
25✔
735

736
  /* Drops the reference to the client context before calling unregisterOpenSSLUser(). */
  ~OpenSSLTLSIOCtx() override
  {
    d_tlsCtx.reset();

    unregisterOpenSSLUser();
  }
9✔
741

742
#if OPENSSL_VERSION_MAJOR >= 3
743
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx, int enc)
744
#else
745
  static int ticketKeyCb(SSL* s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* iv, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx, int enc)
746
#endif
747
  {
484✔
748
    auto* ctx = reinterpret_cast<OpenSSLFrontendContext*>(libssl_get_ticket_key_callback_data(s));
484✔
749
    if (ctx == nullptr) {
484!
750
      return -1;
×
751
    }
×
752

753
    int ret = libssl_ticket_key_callback(s, ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc);
484✔
754
    if (enc == 0) {
484✔
755
      if (ret == 0 || ret == 2) {
30✔
756
        auto* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(s, OpenSSLTLSConnection::getConnectionIndex()));
14✔
757
        if (conn != nullptr) {
14!
758
          if (ret == 0) {
14✔
759
            conn->setUnknownTicketKey();
6✔
760
          }
6✔
761
          else if (ret == 2) {
8!
762
            conn->setResumedFromInactiveTicketKey();
8✔
763
          }
8✔
764
        }
14✔
765
      }
14✔
766
    }
30✔
767

768
    return ret;
484✔
769
  }
484✔
770

771
#ifndef DISABLE_OCSP_STAPLING
772
  /* OCSP stapling callback: arg is the map of OCSP responses registered in the server-side
     constructor. Returns SSL_TLSEXT_ERR_NOACK if either argument is null, otherwise the
     result of libssl_ocsp_stapling_callback(). */
  static int ocspStaplingCb(SSL* ssl, void* arg)
  {
    const bool missingArgument = (ssl == nullptr) || (arg == nullptr);
    if (missingArgument) {
      return SSL_TLSEXT_ERR_NOACK;
    }

    auto* const responses = reinterpret_cast<std::map<int, std::string>*>(arg);
    return libssl_ocsp_stapling_callback(ssl, *responses);
  }
4✔
780
#endif /* DISABLE_OCSP_STAPLING */
781

782
  /* "New session" callback installed on the client context: passes the session to the
     connection object attached to the SSL object and returns 1.
     Returns 0, without touching the session, if the session or the connection is missing. */
  static int newTicketFromServerCb(SSL* ssl, SSL_SESSION* session)
  {
    auto* conn = reinterpret_cast<OpenSSLTLSConnection*>(SSL_get_ex_data(ssl, OpenSSLTLSConnection::getConnectionIndex()));
    if (session == nullptr) {
      return 0;
    }
    if (conn == nullptr) {
      return 0;
    }

    conn->addNewTicket(session);
    return 1;
  }
80✔
792

793
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
794
  {
255✔
795
    handleTicketsKeyRotation(now);
255✔
796

797
    return std::make_unique<OpenSSLTLSConnection>(socket, timeout, d_feContext);
255✔
798
  }
255✔
799

800
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
801
  {
74✔
802
    auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, d_tlsCtx);
74✔
803
    if (d_ktls) {
74!
804
      conn->enableKTLS();
×
805
    }
×
806
    return conn;
74✔
807
  }
74✔
808

809
  void rotateTicketsKey(time_t now) override
810
  {
94✔
811
    d_feContext->d_ticketKeys.rotateTicketsKey(now);
94✔
812

813
    if (d_ticketsKeyRotationDelay > 0) {
94!
814
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
94✔
815
    }
94✔
816
  }
94✔
817

818
  void loadTicketsKeys(const std::string& keyFile) final
819
  {
12✔
820
    d_feContext->d_ticketKeys.loadTicketsKeys(keyFile);
12✔
821

822
    if (d_ticketsKeyRotationDelay > 0) {
12!
823
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
12✔
824
    }
12✔
825
  }
12✔
826

827
  size_t getTicketsKeysCount() override
828
  {
6✔
829
    return d_feContext->d_ticketKeys.getKeysCount();
6✔
830
  }
6✔
831

832
  std::string getName() const override
833
  {
3✔
834
    return "openssl";
3✔
835
  }
3✔
836

837
  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
838
  {
74✔
839
    if (d_feContext && d_feContext->d_tlsCtx) {
74!
840
      d_alpnProtos = protos;
49✔
841
      libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
49✔
842
      return true;
49✔
843
    }
49✔
844
    if (d_tlsCtx) {
25!
845
      return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
25✔
846
    }
25✔
847
    return false;
×
848
  }
25✔
849

850
#ifndef DISABLE_NPN
851
  bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
852
  {
12✔
853
    d_nextProtocolSelectCallback = cb;
12✔
854
    libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
12✔
855
    return true;
12✔
856
  }
12✔
857
#endif /* DISABLE_NPN */
858

859
private:
860
  /* called in a client context, if the client advertised more than one ALPN value and the server returned more than one as well, to select the one to use. */
861
#ifndef DISABLE_NPN
862
  /* NPN selection trampoline: 'arg' is expected to be the OpenSSLTLSIOCtx that
     registered it. Forwards to the stored user callback when there is one and
     maps its boolean result to SSL_TLSEXT_ERR_OK / SSL_TLSEXT_ERR_ALERT_WARNING.
     Without a stored callback SSL_TLSEXT_ERR_OK is returned. */
  static int npnSelectCallback(SSL* /* s */, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
  {
    if (arg == nullptr) {
      return SSL_TLSEXT_ERR_ALERT_WARNING;
    }

    auto* ctx = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);
    if (!ctx->d_nextProtocolSelectCallback) {
      return SSL_TLSEXT_ERR_OK;
    }

    const bool selected = (*ctx->d_nextProtocolSelectCallback)(out, outlen, in, inlen);
    return selected ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_ALERT_WARNING;
  }
×
874
#endif /* NPN */
875

876
  /* Server-side ALPN selection: 'in' / 'inlen' is the list offered by the
     client, encoded as a sequence of (one length byte, protocol bytes) entries.
     Our own protocols (d_alpnProtos) are tried in order, and the first one
     also present in the client's list is returned through 'out' / 'outlen'.
     Returns SSL_TLSEXT_ERR_NOACK when nothing matches, and
     SSL_TLSEXT_ERR_ALERT_WARNING on a null 'arg' or a malformed list. */
  static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
  {
    if (arg == nullptr) {
      return SSL_TLSEXT_ERR_ALERT_WARNING;
    }
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): OpenSSL's API
    const auto* self = reinterpret_cast<OpenSSLTLSIOCtx*>(arg);

    const pdns::views::UnsignedCharView offered(in, inlen);
    // Server preference algorithm as per RFC 7301 section 3.2
    for (const auto& candidate : self->d_alpnProtos) {
      for (size_t cursor = 0; cursor < offered.size();) {
        /* each entry starts with a single byte holding its length */
        const size_t entryLen = offered.at(cursor);
        ++cursor;
        if (entryLen > (inlen - cursor)) {
          /* something is very wrong */
          return SSL_TLSEXT_ERR_ALERT_WARNING;
        }

        const bool sameProtocol = candidate.size() == entryLen && memcmp(&offered.at(cursor), candidate.data(), candidate.size()) == 0;
        if (sameProtocol) {
          *out = &offered.at(cursor);
          *outlen = entryLen;
          return SSL_TLSEXT_ERR_OK;
        }
        cursor += entryLen;
      }
    }

    return SSL_TLSEXT_ERR_NOACK;
  }
128✔
907

908
  std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
909
  std::shared_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
910
  std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
911
  bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
912
  bool d_ktls{false};
913
};
914

915
#endif /* HAVE_LIBSSL */
916

917
#ifdef HAVE_GNUTLS
918
#include <gnutls/gnutls.h>
919
#include <gnutls/x509.h>
920

921
/* Locks the given memory region via libsodium's sodium_mlock() when libsodium
   is available; does nothing otherwise. */
static void safe_memory_lock(void* data, size_t size)
{
#ifndef HAVE_LIBSODIUM
  /* no locking primitive available in this build */
  (void)data;
  (void)size;
#else
  sodium_mlock(data, size);
#endif
}
5✔
927

928
/* Releases a sensitive memory region, using the first available primitive:
   sodium_munlock(), explicit_bzero(), explicit_memset(), gnutls_memset(), or
   a hand-written zeroing loop as a last resort. */
static void safe_memory_release(void* data, size_t size)
{
#ifdef HAVE_LIBSODIUM
  sodium_munlock(data, size);
#elif defined(HAVE_EXPLICIT_BZERO)
  explicit_bzero(data, size);
#elif defined(HAVE_EXPLICIT_MEMSET)
  explicit_memset(data, 0, size);
#elif defined(HAVE_GNUTLS_MEMSET)
  gnutls_memset(data, 0, size);
#else
  /* shamelessly taken from Dovecot's src/lib/safe-memset.c */
  if (size == 0) {
    return;
  }

  /* the buffer is read back through volatile-qualified objects after each memset() */
  volatile unsigned int volatile_zero_idx = 0;
  volatile unsigned char* bytes = reinterpret_cast<volatile unsigned char*>(data);

  do {
    memset(data, 0, size);
  } while (bytes[volatile_zero_idx] != 0);
#endif
}
33✔
951

952
class GnuTLSTicketsKey
953
{
954
public:
955
  GnuTLSTicketsKey()
956
  {
5✔
957
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
5!
958
      throw std::runtime_error("Error generating tickets key for TLS context");
×
959
    }
×
960

961
    safe_memory_lock(d_key.data, d_key.size);
5✔
962
  }
5✔
963

964
  GnuTLSTicketsKey(const std::string& keyFile)
965
  {
×
966
    /* to be sure we are loading the correct amount of data, which
967
       may change between versions, let's generate a correct key first */
968
    if (gnutls_session_ticket_key_generate(&d_key) != GNUTLS_E_SUCCESS) {
×
969
      throw std::runtime_error("Error generating tickets key (before parsing key file) for TLS context");
×
970
    }
×
971

972
    safe_memory_lock(d_key.data, d_key.size);
×
973

974
    try {
×
975
      ifstream file(keyFile);
×
976
      file.read(reinterpret_cast<char*>(d_key.data), d_key.size);
×
977

978
      if (file.fail()) {
×
979
        file.close();
×
980
        throw std::runtime_error("Invalid GnuTLS tickets key file " + keyFile);
×
981
      }
×
982

983
      file.close();
×
984
    }
×
985
    catch (const std::exception& e) {
×
986
      safe_memory_release(d_key.data, d_key.size);
×
987
      gnutls_free(d_key.data);
×
988
      d_key.data = nullptr;
×
989
      throw;
×
990
    }
×
991
  }
×
992
  [[nodiscard]] std::string content() const
993
  {
×
994
    std::string result{};
×
995
    if (d_key.data != nullptr && d_key.size > 0) {
×
996
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
997
      result.append(reinterpret_cast<const char*>(d_key.data), d_key.size);
×
998
      safe_memory_lock(result.data(), result.size());
×
999
    }
×
1000
    return result;
×
1001
  }
×
1002

1003
  ~GnuTLSTicketsKey()
1004
  {
2✔
1005
    if (d_key.data != nullptr && d_key.size > 0) {
2!
1006
      safe_memory_release(d_key.data, d_key.size);
2✔
1007
    }
2✔
1008
    gnutls_free(d_key.data);
2✔
1009
    d_key.data = nullptr;
2✔
1010
  }
2✔
1011
  const gnutls_datum_t& getKey() const
1012
  {
14✔
1013
    return d_key;
14✔
1014
  }
14✔
1015

1016
private:
1017
  gnutls_datum_t d_key{nullptr, 0};
1018
};
1019

1020
class GnuTLSSession : public TLSSession
1021
{
1022
public:
1023
  GnuTLSSession(gnutls_datum_t& sess): d_sess(sess)
1024
  {
49✔
1025
    sess.data = nullptr;
49✔
1026
    sess.size = 0;
49✔
1027
  }
49✔
1028

1029
  ~GnuTLSSession() override
1030
  {
31✔
1031
    if (d_sess.data != nullptr && d_sess.size > 0) {
31!
1032
      safe_memory_release(d_sess.data, d_sess.size);
31✔
1033
    }
31✔
1034
    gnutls_free(d_sess.data);
31✔
1035
    d_sess.data = nullptr;
31✔
1036
  }
31✔
1037

1038
  const gnutls_datum_t& getNative()
1039
  {
31✔
1040
    return d_sess;
31✔
1041
  }
31✔
1042

1043
private:
1044
  gnutls_datum_t d_sess{nullptr, 0};
1045
};
1046

1047
class GnuTLSConnection: public TLSConnection
1048
{
1049
public:
1050
  /* server side connection */
1051
  GnuTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_creds(creds), d_ticketsKey(ticketsKey), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit))
1052
  {
14✔
1053
    unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
14✔
1054
#ifdef GNUTLS_NO_SIGNAL
14✔
1055
    sslOptions |= GNUTLS_NO_SIGNAL;
14✔
1056
#endif
14✔
1057

1058
    d_socket = socket;
14✔
1059

1060
    gnutls_session_t conn;
14✔
1061
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
14!
1062
      throw std::runtime_error("Error creating TLS connection");
×
1063
    }
×
1064

1065
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
14✔
1066
    conn = nullptr;
14✔
1067

1068
    if (gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get()) != GNUTLS_E_SUCCESS) {
14!
1069
      throw std::runtime_error("Error setting certificate and key to TLS connection");
×
1070
    }
×
1071

1072
    if (gnutls_priority_set(d_conn.get(), priorityCache) != GNUTLS_E_SUCCESS) {
14!
1073
      throw std::runtime_error("Error setting ciphers to TLS connection");
×
1074
    }
×
1075

1076
    if (enableTickets && d_ticketsKey) {
14!
1077
      const gnutls_datum_t& key = d_ticketsKey->getKey();
14✔
1078
      if (gnutls_session_ticket_enable_server(d_conn.get(), &key) != GNUTLS_E_SUCCESS) {
14!
1079
        throw std::runtime_error("Error setting the tickets key to TLS connection");
×
1080
      }
×
1081
    }
14✔
1082

1083
    gnutls_transport_set_int(d_conn.get(), d_socket);
14✔
1084

1085
    /* timeouts are in milliseconds */
1086
    gnutls_handshake_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
14✔
1087
    gnutls_record_set_timeout(d_conn.get(), timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
14✔
1088
  }
14✔
1089

1090
  /* client-side connection */
1091
  GnuTLSConnection(const std::string& host, int socket, const struct timeval& timeout, std::shared_ptr<gnutls_certificate_credentials_st>& creds, const gnutls_priority_t priorityCache, bool validateCerts): d_creds(creds), d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_host(host), d_client(true)
1092
  {
46✔
1093
    unsigned int sslOptions = GNUTLS_CLIENT | GNUTLS_NONBLOCK;
46✔
1094
#ifdef GNUTLS_NO_SIGNAL
46✔
1095
    sslOptions |= GNUTLS_NO_SIGNAL;
46✔
1096
#endif
46✔
1097

1098
    d_socket = socket;
46✔
1099

1100
    gnutls_session_t conn;
46✔
1101
    if (gnutls_init(&conn, sslOptions) != GNUTLS_E_SUCCESS) {
46!
1102
      throw std::runtime_error("Error creating TLS connection");
×
1103
    }
×
1104

1105
    d_conn = std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(conn, gnutls_deinit);
46✔
1106
    conn = nullptr;
46✔
1107

1108
    int rc = gnutls_credentials_set(d_conn.get(), GNUTLS_CRD_CERTIFICATE, d_creds.get());
46✔
1109
    if (rc != GNUTLS_E_SUCCESS) {
46!
1110
      throw std::runtime_error("Error setting certificate and key to TLS connection: " + std::string(gnutls_strerror(rc)));
×
1111
    }
×
1112

1113
    rc = gnutls_priority_set(d_conn.get(), priorityCache);
46✔
1114
    if (rc != GNUTLS_E_SUCCESS) {
46!
1115
      throw std::runtime_error("Error setting ciphers to TLS connection: " + std::string(gnutls_strerror(rc)));
×
1116
    }
×
1117

1118
    gnutls_transport_set_int(d_conn.get(), d_socket);
46✔
1119

1120
    /* timeouts are in milliseconds */
1121
    gnutls_handshake_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
46✔
1122
    gnutls_record_set_timeout(d_conn.get(),  timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
46✔
1123

1124
#ifdef HAVE_GNUTLS_SESSION_SET_VERIFY_CERT
46✔
1125
    if (validateCerts && !d_host.empty()) {
46!
1126
      gnutls_session_set_verify_cert(d_conn.get(), d_host.c_str(), GNUTLS_VERIFY_ALLOW_UNSORTED_CHAIN);
30✔
1127
      rc = gnutls_server_name_set(d_conn.get(), GNUTLS_NAME_DNS, d_host.c_str(), d_host.size());
30✔
1128
      if (rc != GNUTLS_E_SUCCESS) {
30!
1129
        throw std::runtime_error("Error setting the SNI value to '" + d_host + "' on TLS connection: " + std::string(gnutls_strerror(rc)));
×
1130
      }
×
1131
    }
30✔
1132
#else
1133
    /* no hostname validation for you */
1134
#endif
1135

1136
    /* allow access to our data in the callbacks */
1137
    gnutls_session_set_ptr(d_conn.get(), this);
46✔
1138
    gnutls_handshake_set_hook_function(d_conn.get(), GNUTLS_HANDSHAKE_NEW_SESSION_TICKET, GNUTLS_HOOK_POST, newTicketFromServerCb);
46✔
1139
  }
46✔
1140

1141
  /* The callback prototype changed in 3.4.0. */
1142
#if GNUTLS_VERSION_NUMBER >= 0x030400
1143
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */, const gnutls_datum_t* /* msg */)
1144
#else
1145
  static int newTicketFromServerCb(gnutls_session_t session, unsigned int htype, unsigned post, unsigned int /* incoming */)
1146
#endif /* GNUTLS_VERSION_NUMBER >= 0x030400 */
1147
  {
49✔
1148
    if (htype != GNUTLS_HANDSHAKE_NEW_SESSION_TICKET || post != GNUTLS_HOOK_POST || session == nullptr) {
49!
1149
      return 0;
×
1150
    }
×
1151

1152
    GnuTLSConnection* conn = reinterpret_cast<GnuTLSConnection*>(gnutls_session_get_ptr(session));
49✔
1153
    if (conn == nullptr) {
49!
1154
      return 0;
×
1155
    }
×
1156

1157
    gnutls_datum_t sess{nullptr, 0};
49✔
1158
    auto ret = gnutls_session_get_data2(session, &sess);
49✔
1159
    /* GnuTLS returns a 'fake' ticket of 4 bytes set to zero when there is no ticket available */
1160
    if (ret != GNUTLS_E_SUCCESS || sess.size <= 4) {
49!
1161
      throw std::runtime_error("Error getting GnuTLSSession: " + std::string(gnutls_strerror(ret)));
×
1162
    }
×
1163
    conn->d_tlsSessions.push_back(std::make_unique<GnuTLSSession>(sess));
49✔
1164
    return 0;
49✔
1165
  }
49✔
1166

1167
  /* Starts (or continues) the client-side handshake without blocking.
     Returns IOState::Done once the handshake has completed, or
     IOState::NeedRead / IOState::NeedWrite when GnuTLS needs the socket to
     become readable / writable first. Throws std::runtime_error on any other
     outcome. When 'fastOpen' is set and GnuTLS supports it, TCP Fast Open is
     enabled toward 'remote' beforehand. */
  IOState tryConnect(bool fastOpen, [[maybe_unused]] const ComboAddress& remote) override
  {
    if (fastOpen) {
#ifdef HAVE_GNUTLS_TRANSPORT_SET_FASTOPEN
      gnutls_transport_set_fastopen(d_conn.get(), d_socket, const_cast<struct sockaddr*>(reinterpret_cast<const struct sockaddr*>(&remote)), remote.getSocklen(), 0);
#endif
    }

    int ret = 0;
    do {
      ret = gnutls_handshake(d_conn.get());
      if (ret == GNUTLS_E_SUCCESS) {
        d_handshakeDone = true;
        return IOState::Done;
      }
      if (ret == GNUTLS_E_AGAIN) {
        /* a direction of 0 means GnuTLS was interrupted while reading */
        const int direction = gnutls_record_get_direction(d_conn.get());
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
      }
      if (gnutls_error_is_fatal(ret) != 0 || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
        /* reported below, like any other unexpected result */
        break;
      }
    } while (ret == GNUTLS_E_INTERRUPTED);

    throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
  }
46✔
1194

1195
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override
1196
  {
×
1197
    struct timeval start = {0, 0};
×
1198
    struct timeval remainingTime = timeout;
×
1199
    if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
1200
      gettimeofday(&start, nullptr);
×
1201
    }
×
1202

1203
    IOState state;
×
1204
    do {
×
1205
      state = tryConnect(fastOpen, remote);
×
1206
      if (state == IOState::Done) {
×
1207
        return;
×
1208
      }
×
1209
      else if (state == IOState::NeedRead) {
×
1210
        int result = waitForData(d_socket, remainingTime.tv_sec, remainingTime.tv_usec);
×
1211
        if (result <= 0) {
×
1212
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
×
1213
        }
×
1214
      }
×
1215
      else if (state == IOState::NeedWrite) {
×
1216
        int result = waitForRWData(d_socket, false, remainingTime.tv_sec, remainingTime.tv_usec);
×
1217
        if (result <= 0) {
×
1218
          throw std::runtime_error("Error reading from TLS connection: " + std::to_string(result));
×
1219
        }
×
1220
      }
×
1221

1222
      if (timeout.tv_sec != 0 || timeout.tv_usec != 0) {
×
1223
        struct timeval now;
×
1224
        gettimeofday(&now, nullptr);
×
1225
        struct timeval elapsed = now - start;
×
1226
        if (now < start || remainingTime < elapsed) {
×
1227
          throw runtime_error("Timeout while establishing TLS connection");
×
1228
        }
×
1229
        start = now;
×
1230
        remainingTime = remainingTime - elapsed;
×
1231
      }
×
1232
    }
×
1233
    while (state != IOState::Done);
×
1234
  }
×
1235

1236
  void doHandshake() override
1237
  {
×
1238
    int ret = 0;
×
1239
    do {
×
1240
      ret = gnutls_handshake(d_conn.get());
×
1241
      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
×
1242
        if (d_client) {
×
1243
          throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
×
1244
        }
×
1245
        else {
×
1246
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
×
1247
        }
×
1248
      }
×
1249
    }
×
1250
    while (ret != GNUTLS_E_SUCCESS && ret == GNUTLS_E_INTERRUPTED);
×
1251

1252
    d_handshakeDone = true;
×
1253
  }
×
1254

1255
  /* Drives the handshake without blocking.
     Returns IOState::Done once the handshake has completed, or
     IOState::NeedRead / IOState::NeedWrite when GnuTLS needs the socket to
     become readable / writable first. Throws std::runtime_error on any other
     outcome: the message mentions "establishing" a connection on the client
     side and "accepting" one on the server side, and on the client side it
     includes the certificate verification status when that is what failed
     and GnuTLS is able to report it. */
  IOState tryHandshake() override
  {
    int ret = 0;

    do {
      ret = gnutls_handshake(d_conn.get());
      if (ret == GNUTLS_E_SUCCESS) {
        d_handshakeDone = true;
        return IOState::Done;
      }
      if (ret == GNUTLS_E_AGAIN) {
        /* a direction of 0 means GnuTLS was interrupted while reading */
        const int direction = gnutls_record_get_direction(d_conn.get());
        return direction == 0 ? IOState::NeedRead : IOState::NeedWrite;
      }
      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
        if (!d_client) {
          throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
        }

        std::string error;
#ifdef HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS
        if (ret == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) {
          gnutls_datum_t out{nullptr, 0};
          if (gnutls_certificate_verification_status_print(gnutls_session_get_verify_cert_status(d_conn.get()), gnutls_certificate_type_get(d_conn.get()), &out, 0) == 0) {
            error = " (" + std::string(reinterpret_cast<const char*>(out.data)) + ")";
            gnutls_free(out.data);
          }
        }
#endif /* HAVE_GNUTLS_SESSION_GET_VERIFY_CERT_STATUS */
        throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)) + error);
      }
    } while (ret == GNUTLS_E_INTERRUPTED);

    /* neither a success, a retry, nor an error reported as fatal */
    if (d_client) {
      throw std::runtime_error("Error establishing a new connection: " + std::string(gnutls_strerror(ret)));
    }
    throw std::runtime_error("Error accepting a new connection: " + std::string(gnutls_strerror(ret)));
  }
×
1296

1297
  /* Writes the content of 'buffer' between 'pos' and 'toWrite' without
     blocking, advancing 'pos' by the number of bytes sent.
     Returns IOState::Done when everything has been sent, IOState::NeedWrite
     when GnuTLS reports GNUTLS_E_AGAIN, or whatever tryHandshake() returned
     when the handshake is still in progress. Throws std::runtime_error on a
     fatal error or when GnuTLS reports that nothing was sent. */
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
  {
    if (!d_handshakeDone) {
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
      const auto handshakeState = tryHandshake();
      if (handshakeState != IOState::Done) {
        return handshakeState;
      }
    }

    do {
      const ssize_t sent = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
      if (sent > 0) {
        pos += static_cast<size_t>(sent);
        continue;
      }
      if (sent == 0) {
        throw std::runtime_error("Error writing to TLS connection");
      }

      /* negative values are GnuTLS error codes */
      if (gnutls_error_is_fatal(sent)) {
        throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(sent)));
      }
      if (sent == GNUTLS_E_AGAIN) {
        return IOState::NeedWrite;
      }
      vinfolog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(sent));
    }
    while (pos < toWrite);

    return IOState::Done;
  }
252✔
1329

1330
  /* Reads into 'buffer' from 'pos' up to 'toRead' without blocking, advancing
     'pos' by the number of bytes received.
     Returns IOState::Done when 'toRead' has been reached (or, with
     'allowIncomplete', as soon as some data has been received),
     IOState::NeedRead when GnuTLS reports GNUTLS_E_AGAIN, or whatever
     tryHandshake() returned when the handshake is still in progress.
     Throws std::runtime_error on EOF or on a fatal error. */
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete) override
  {
    if (!d_handshakeDone) {
      /* As opposed to OpenSSL, GnuTLS will not transparently finish the handshake for us,
         we need to keep calling gnutls_handshake() until the handshake has been finished. */
      auto state = tryHandshake();
      if (state != IOState::Done) {
        return state;
      }
    }

    do {
      ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
      if (res == 0) {
        throw std::runtime_error("EOF while reading from TLS connection");
      }
      else if (res > 0) {
        pos += static_cast<size_t>(res);
        if (allowIncomplete) {
          break;
        }
      }
      else {
        /* negative values are GnuTLS error codes */
        if (gnutls_error_is_fatal(res)) {
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
        }
        else if (res == GNUTLS_E_AGAIN) {
          return IOState::NeedRead;
        }
        vinfolog("Warning, non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
      }
    }
    while (pos < toRead);

    return IOState::Done;
  }
501✔
1365

1366
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout, bool allowIncomplete) override
1367
  {
×
1368
    size_t got = 0;
×
1369
    struct timeval start{0,0};
×
1370
    struct timeval  remainingTime = totalTimeout;
×
1371
    if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
1372
      gettimeofday(&start, nullptr);
×
1373
    }
×
1374

1375
    do {
×
1376
      ssize_t res = gnutls_record_recv(d_conn.get(), (reinterpret_cast<char *>(buffer) + got), bufferSize - got);
×
1377
      if (res == 0) {
×
1378
        throw std::runtime_error("EOF while reading from TLS connection");
×
1379
      }
×
1380
      else if (res > 0) {
×
1381
        got += static_cast<size_t>(res);
×
1382
        if (allowIncomplete) {
×
1383
          break;
×
1384
        }
×
1385
      }
×
1386
      else if (res < 0) {
×
1387
        if (gnutls_error_is_fatal(res)) {
×
1388
          throw std::runtime_error("Fatal error reading from TLS connection: " + std::string(gnutls_strerror(res)));
×
1389
        }
×
1390
        else if (res == GNUTLS_E_AGAIN) {
×
1391
          int result = waitForData(d_socket, readTimeout.tv_sec, readTimeout.tv_usec);
×
1392
          if (result <= 0) {
×
1393
            throw std::runtime_error("Error while waiting to read from TLS connection: " + std::to_string(result));
×
1394
          }
×
1395
        }
×
1396
        else {
×
1397
          vinfolog("Non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
×
1398
        }
×
1399
      }
×
1400

1401
      if (totalTimeout.tv_sec != 0 || totalTimeout.tv_usec != 0) {
×
1402
        struct timeval now;
×
1403
        gettimeofday(&now, nullptr);
×
1404
        struct timeval elapsed = now - start;
×
1405
        if (now < start || remainingTime < elapsed) {
×
1406
          throw runtime_error("Timeout while reading data");
×
1407
        }
×
1408
        start = now;
×
1409
        remainingTime = remainingTime - elapsed;
×
1410
      }
×
1411
    }
×
1412
    while (got < bufferSize);
×
1413

1414
    return got;
×
1415
  }
×
1416

1417
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
1418
  {
×
1419
    size_t got = 0;
×
1420

1421
    do {
×
1422
      ssize_t res = gnutls_record_send(d_conn.get(), (reinterpret_cast<const char *>(buffer) + got), bufferSize - got);
×
1423
      if (res == 0) {
×
1424
        throw std::runtime_error("Error writing to TLS connection");
×
1425
      }
×
1426
      else if (res > 0) {
×
1427
        got += static_cast<size_t>(res);
×
1428
      }
×
1429
      else if (res < 0) {
×
1430
        if (gnutls_error_is_fatal(res)) {
×
1431
          throw std::runtime_error("Fatal error writing to TLS connection: " + std::string(gnutls_strerror(res)));
×
1432
        }
×
1433
        else if (res == GNUTLS_E_AGAIN) {
×
1434
          int result = waitForRWData(d_socket, false, writeTimeout.tv_sec, writeTimeout.tv_usec);
×
1435
          if (result <= 0) {
×
1436
            throw std::runtime_error("Error waiting to write to TLS connection: " + std::to_string(result));
×
1437
          }
×
1438
        }
×
1439
        else {
×
1440
          vinfolog("Non-fatal error while writing to TLS connection: %s", gnutls_strerror(res));
×
1441
        }
×
1442
      }
×
1443
    }
×
1444
    while (got < bufferSize);
×
1445

1446
    return got;
×
1447
  }
×
1448

1449
  bool isUsable() const override
1450
  {
2✔
1451
    if (!d_conn) {
2!
1452
      return false;
×
1453
    }
×
1454

1455
    /* as far as I can tell we can't peek so we cannot do better */
1456
    return isTCPSocketUsable(d_socket);
2✔
1457
  }
2✔
1458

1459
  std::string getServerNameIndication() const override
1460
  {
112✔
1461
    if (d_conn) {
112!
1462
      unsigned int type;
112✔
1463
      size_t name_len = 256;
112✔
1464
      std::string sni;
112✔
1465
      sni.resize(name_len);
112✔
1466

1467
      int res = gnutls_server_name_get(d_conn.get(), const_cast<char*>(sni.c_str()), &name_len, &type, 0);
112✔
1468
      if (res == GNUTLS_E_SUCCESS) {
112!
1469
        sni.resize(name_len);
112✔
1470
        return sni;
112✔
1471
      }
112✔
1472
    }
112✔
1473
    return std::string();
×
1474
  }
112✔
1475

1476
  std::vector<uint8_t> getNextProtocol() const override
1477
  {
×
1478
    std::vector<uint8_t> result;
×
1479
    if (!d_conn) {
×
1480
      return result;
×
1481
    }
×
1482
    gnutls_datum_t next;
×
1483
    if (gnutls_alpn_get_selected_protocol(d_conn.get(), &next) != GNUTLS_E_SUCCESS) {
×
1484
      return result;
×
1485
    }
×
1486
    result.insert(result.end(), next.data, next.data + next.size);
×
1487
    return result;
×
1488
  }
×
1489

1490
  /* Maps the protocol version negotiated by GnuTLS to a LibsslTLSVersion,
     returning LibsslTLSVersion::Unknown for anything that is not TLS 1.0 to
     TLS 1.3 (TLS 1.3 only being known to GnuTLS >= 3.6.3). */
  LibsslTLSVersion getTLSVersion() const override
  {
    const auto negotiated = gnutls_protocol_get_version(d_conn.get());

    if (negotiated == GNUTLS_TLS1_0) {
      return LibsslTLSVersion::TLS10;
    }
    if (negotiated == GNUTLS_TLS1_1) {
      return LibsslTLSVersion::TLS11;
    }
    if (negotiated == GNUTLS_TLS1_2) {
      return LibsslTLSVersion::TLS12;
    }
#if GNUTLS_VERSION_NUMBER >= 0x030603
    if (negotiated == GNUTLS_TLS1_3) {
      return LibsslTLSVersion::TLS13;
    }
#endif /* GNUTLS_VERSION_NUMBER >= 0x030603 */

    return LibsslTLSVersion::Unknown;
  }
112✔
1508

1509
  bool hasSessionBeenResumed() const override
1510
  {
46✔
1511
    if (d_conn) {
46!
1512
      return gnutls_session_is_resumed(d_conn.get()) != 0;
46✔
1513
    }
46✔
1514
    return false;
×
1515
  }
46✔
1516

1517
  std::vector<std::unique_ptr<TLSSession>> getSessions() override
1518
  {
36✔
1519
    return std::move(d_tlsSessions);
36✔
1520
  }
36✔
1521

1522
  void setSession(std::unique_ptr<TLSSession>& session) override
1523
  {
31✔
1524
    auto sess = dynamic_cast<GnuTLSSession*>(session.get());
31✔
1525
    if (!sess) {
31!
1526
      throw std::runtime_error("Unable to convert GnuTLS session");
×
1527
    }
×
1528

1529
    auto native = sess->getNative();
31✔
1530
    auto ret = gnutls_session_set_data(d_conn.get(), native.data, native.size);
31✔
1531
    if (ret != GNUTLS_E_SUCCESS) {
31!
1532
      throw std::runtime_error("Error setting up GnuTLS session: " + std::string(gnutls_strerror(ret)));
×
1533
    }
×
1534
    session.reset();
31✔
1535
  }
31✔
1536

1537
  void close() override
1538
  {
54✔
1539
    if (d_conn) {
54!
1540
      gnutls_bye(d_conn.get(), GNUTLS_SHUT_RDWR);
54✔
1541
    }
54✔
1542
  }
54✔
1543

1544
  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos)
1545
  {
60✔
1546
    std::vector<gnutls_datum_t> values;
60✔
1547
    values.reserve(protos.size());
60✔
1548
    for (const auto& proto : protos) {
60✔
1549
      gnutls_datum_t value;
60✔
1550
      value.data = const_cast<uint8_t*>(proto.data());
60✔
1551
      value.size = proto.size();
60✔
1552
      values.push_back(value);
60✔
1553
    }
60✔
1554
    unsigned int flags = 0;
60✔
1555
#if GNUTLS_VERSION_NUMBER >= 0x030500
60✔
1556
    flags |= GNUTLS_ALPN_MANDATORY;
60✔
1557
#elif defined(GNUTLS_ALPN_MAND)
1558
    flags |= GNUTLS_ALPN_MAND;
1559
#endif
1560
    return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), flags);
60✔
1561
  }
60✔
1562

1563
  std::vector<int> getAsyncFDs() override
1564
  {
14✔
1565
    return {};
14✔
1566
  }
14✔
1567

1568
private:
1569
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1570
  std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
1571
  std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
1572
  std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
1573
  std::string d_host;
1574
  bool d_client{false};
1575
  bool d_handshakeDone{false};
1576
};
1577

1578
class GnuTLSIOCtx: public TLSCtx
1579
{
1580
public:
1581
  /* server side context */
1582
  GnuTLSIOCtx(TLSFrontend& fe): d_enableTickets(fe.d_tlsConfig.d_enableTickets)
1583
  {
5✔
1584
    int rc = 0;
5✔
1585
    d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
5✔
1586

1587
    gnutls_certificate_credentials_t creds;
5✔
1588
    rc = gnutls_certificate_allocate_credentials(&creds);
5✔
1589
    if (rc != GNUTLS_E_SUCCESS) {
5!
1590
      throw std::runtime_error("Error allocating credentials for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
×
1591
    }
×
1592

1593
    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
5✔
1594
    creds = nullptr;
5✔
1595

1596
    for (const auto& pair : fe.d_tlsConfig.d_certKeyPairs) {
5✔
1597
      rc = gnutls_certificate_set_x509_key_file(d_creds.get(), pair.d_cert.c_str(), pair.d_key->c_str(), GNUTLS_X509_FMT_PEM);
5✔
1598
      if (rc != GNUTLS_E_SUCCESS) {
5!
1599
        throw std::runtime_error("Error loading certificate ('" + pair.d_cert + "') and key ('" + pair.d_key.value() + "') for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
×
1600
      }
×
1601
    }
5✔
1602

1603
#ifndef DISABLE_OCSP_STAPLING
5✔
1604
    size_t count = 0;
5✔
1605
    for (const auto& file : fe.d_tlsConfig.d_ocspFiles) {
5✔
1606
      rc = gnutls_certificate_set_ocsp_status_request_file(d_creds.get(), file.c_str(), count);
3✔
1607
      if (rc != GNUTLS_E_SUCCESS) {
3✔
1608
        warnlog("Error loading OCSP response from file '%s' for certificate ('%s') and key ('%s') for TLS context on %s: %s", file, fe.d_tlsConfig.d_certKeyPairs.at(count).d_cert, fe.d_tlsConfig.d_certKeyPairs.at(count).d_key.value(), fe.d_addr.toStringWithPort(), gnutls_strerror(rc));
1✔
1609
      }
1✔
1610
      ++count;
3✔
1611
    }
3✔
1612
#endif /* DISABLE_OCSP_STAPLING */
5✔
1613

1614
#if GNUTLS_VERSION_NUMBER >= 0x030600
5✔
1615
    rc = gnutls_certificate_set_known_dh_params(d_creds.get(), GNUTLS_SEC_PARAM_HIGH);
5✔
1616
    if (rc != GNUTLS_E_SUCCESS) {
5!
1617
      throw std::runtime_error("Error setting DH params for TLS context on " + fe.d_addr.toStringWithPort() + ": " + gnutls_strerror(rc));
×
1618
    }
×
1619
#endif
5✔
1620

1621
    rc = gnutls_priority_init(&d_priorityCache, fe.d_tlsConfig.d_ciphers.empty() ? "NORMAL" : fe.d_tlsConfig.d_ciphers.c_str(), nullptr);
5!
1622
    if (rc != GNUTLS_E_SUCCESS) {
5!
1623
      throw std::runtime_error("Error setting up TLS cipher preferences to '" + fe.d_tlsConfig.d_ciphers + "' (" + gnutls_strerror(rc) + ") on " + fe.d_addr.toStringWithPort());
×
1624
    }
×
1625

1626
    try {
5✔
1627
      if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
5!
1628
        handleTicketsKeyRotation(time(nullptr));
5✔
1629
      }
5✔
1630
      else {
×
1631
        GnuTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
×
1632
      }
×
1633
    }
5✔
1634
    catch(const std::runtime_error& e) {
5✔
1635
      throw std::runtime_error("Error generating tickets key for TLS context on " + fe.d_addr.toStringWithPort() + ": " + e.what());
×
1636
    }
×
1637
  }
5✔
1638

1639
  /* client side context */
1640
  GnuTLSIOCtx(const TLSContextParameters& params): d_contextParameters(std::make_unique<TLSContextParameters>(params)), d_enableTickets(true), d_validateCerts(params.d_validateCertificates)
1641
  {
18✔
1642
    int rc = 0;
18✔
1643

1644
    gnutls_certificate_credentials_t creds;
18✔
1645
    rc = gnutls_certificate_allocate_credentials(&creds);
18✔
1646
    if (rc != GNUTLS_E_SUCCESS) {
18!
1647
      throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
×
1648
    }
×
1649

1650
    d_creds = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
18✔
1651
    creds = nullptr;
18✔
1652

1653
    if (params.d_validateCertificates) {
18✔
1654
      if (params.d_caStore.empty()) {
14!
1655
#if GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703
1656
        /* see https://gitlab.com/gnutls/gnutls/-/issues/1277 */
1657
        std::cerr<<"Warning: GnuTLS 3.7.0 - 3.7.2 have a memory leak when validating server certificates in some configurations (PKCS11 support enabled, and a default PKCS11 trust store), please consider upgrading GnuTLS, using the OpenSSL provider for outgoing connections, or explicitly setting a CA store"<<std::endl;
1658
#endif /* GNUTLS_VERSION_NUMBER >= 0x030700 && GNUTLS_VERSION_NUMBER < 0x030703 */
1659
        rc = gnutls_certificate_set_x509_system_trust(d_creds.get());
×
1660
        if (rc < 0) {
×
1661
          throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
×
1662
        }
×
1663
      }
×
1664
      else {
14✔
1665
        rc = gnutls_certificate_set_x509_trust_file(d_creds.get(), params.d_caStore.c_str(), GNUTLS_X509_FMT_PEM);
14✔
1666
        if (rc < 0) {
14!
1667
          throw std::runtime_error("Error adding '" + params.d_caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
×
1668
        }
×
1669
      }
14✔
1670
    }
14✔
1671

1672
    rc = gnutls_priority_init(&d_priorityCache, params.d_ciphers.empty() ? "NORMAL" : params.d_ciphers.c_str(), nullptr);
18!
1673
    if (rc != GNUTLS_E_SUCCESS) {
18!
1674
      throw std::runtime_error("Error setting up TLS cipher preferences to 'NORMAL' (" + std::string(gnutls_strerror(rc)) + ")");
×
1675
    }
×
1676
  }
18✔
1677

1678
  /* Drops our reference to the credentials first, then releases the priority
     cache if one has been set up. */
  ~GnuTLSIOCtx() override
  {
    d_creds.reset();

    if (d_priorityCache == nullptr) {
      return;
    }
    gnutls_priority_deinit(d_priorityCache);
  }
2✔
1686

1687
  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
1688
  {
14✔
1689
    handleTicketsKeyRotation(now);
14✔
1690

1691
    std::shared_ptr<GnuTLSTicketsKey> ticketsKey;
14✔
1692
    {
14✔
1693
      ticketsKey = *(d_ticketsKey.read_lock());
14✔
1694
    }
14✔
1695

1696
    auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds, d_priorityCache, ticketsKey, d_enableTickets);
14✔
1697
    if (!d_protos.empty()) {
14!
1698
      connection->setALPNProtos(d_protos);
14✔
1699
    }
14✔
1700
    return connection;
14✔
1701
  }
14✔
1702

1703
  static std::shared_ptr<gnutls_certificate_credentials_st> getPerThreadCredentials(bool validate, const std::string& caStore)
1704
  {
46✔
1705
    static thread_local std::map<std::pair<bool, std::string>, std::shared_ptr<gnutls_certificate_credentials_st>> t_credentials;
46✔
1706
    auto& entry = t_credentials[{validate, caStore}];
46✔
1707
    if (!entry) {
46✔
1708
      gnutls_certificate_credentials_t creds;
18✔
1709
      int rc = gnutls_certificate_allocate_credentials(&creds);
18✔
1710
      if (rc != GNUTLS_E_SUCCESS) {
18!
1711
        throw std::runtime_error("Error allocating credentials for TLS context: " + std::string(gnutls_strerror(rc)));
×
1712
      }
×
1713

1714
      entry = std::shared_ptr<gnutls_certificate_credentials_st>(creds, gnutls_certificate_free_credentials);
18✔
1715
      creds = nullptr;
18✔
1716

1717
      if (validate) {
18✔
1718
        if (caStore.empty()) {
13!
1719
          rc = gnutls_certificate_set_x509_system_trust(entry.get());
×
1720
          if (rc < 0) {
×
1721
            throw std::runtime_error("Error adding the system's default trusted CAs: " + std::string(gnutls_strerror(rc)));
×
1722
          }
×
1723
        }
×
1724
        else {
13✔
1725
          rc = gnutls_certificate_set_x509_trust_file(entry.get(), caStore.c_str(), GNUTLS_X509_FMT_PEM);
13✔
1726
          if (rc < 0) {
13!
1727
            throw std::runtime_error("Error adding '" + caStore + "' to the trusted CAs: " + std::string(gnutls_strerror(rc)));
×
1728
          }
×
1729
        }
13✔
1730
      }
13✔
1731
    }
18✔
1732
    return entry;
46✔
1733
  }
46✔
1734

1735
  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool, int socket, const struct timeval& timeout) override
1736
  {
46✔
1737
    auto creds = getPerThreadCredentials(d_contextParameters->d_validateCertificates, d_contextParameters->d_caStore);
46✔
1738
    auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, creds, d_priorityCache, d_validateCerts);
46✔
1739
    if (!d_protos.empty()) {
46!
1740
      connection->setALPNProtos(d_protos);
46✔
1741
    }
46✔
1742
    return connection;
46✔
1743
  }
46✔
1744

1745
  void addTicketsKey(time_t now, std::shared_ptr<GnuTLSTicketsKey>&& newKey)
1746
  {
5✔
1747
    if (!d_enableTickets) {
5!
1748
      return;
×
1749
    }
×
1750

1751
    {
5✔
1752
      *(d_ticketsKey.write_lock()) = std::move(newKey);
5✔
1753
    }
5✔
1754

1755
    if (d_ticketsKeyRotationDelay > 0) {
5!
1756
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
5✔
1757
    }
5✔
1758

1759
    if (TLSCtx::hasTicketsKeyAddedHook()) {
5!
1760
      auto ticketsKey = *(d_ticketsKey.read_lock());
×
1761
      auto content = ticketsKey->content();
×
1762
      TLSCtx::getTicketsKeyAddedHook()(content);
×
1763
      safe_memory_release(content.data(), content.size());
×
1764
    }
×
1765
  }
5✔
1766
  void rotateTicketsKey(time_t now) override
1767
  {
5✔
1768
    if (!d_enableTickets) {
5!
1769
      return;
×
1770
    }
×
1771

1772
    auto newKey = std::make_shared<GnuTLSTicketsKey>();
5✔
1773
    addTicketsKey(now, std::move(newKey));
5✔
1774
  }
5✔
1775
  void loadTicketsKeys(const std::string& file) final
1776
  {
×
1777
    if (!d_enableTickets) {
×
1778
      return;
×
1779
    }
×
1780

1781
    auto newKey = std::make_shared<GnuTLSTicketsKey>(file);
×
1782
    addTicketsKey(time(nullptr), std::move(newKey));
×
1783
  }
×
1784

1785
  size_t getTicketsKeysCount() override
1786
  {
×
1787
    return *(d_ticketsKey.read_lock()) != nullptr ? 1 : 0;
×
1788
  }
×
1789

1790
  std::string getName() const override
1791
  {
3✔
1792
    return "gnutls";
3✔
1793
  }
3✔
1794

1795
  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
1796
  {
23✔
1797
#ifdef HAVE_GNUTLS_ALPN_SET_PROTOCOLS
23✔
1798
    d_protos = protos;
23✔
1799
    return true;
23✔
1800
#else
1801
    return false;
1802
#endif
1803
  }
23✔
1804

1805
private:
1806
  /* client context parameters */
  /* copy of the parameters passed to the client-side constructor, read by getClientConnection() */
1807
  std::unique_ptr<TLSContextParameters> d_contextParameters{nullptr};
1808
  /* credentials allocated by the constructors, passed to the connections created by getConnection() */
  std::shared_ptr<gnutls_certificate_credentials_st> d_creds;
1809
  /* ALPN protocols stored by setALPNProtos(), set on every new connection when not empty */
  std::vector<std::vector<uint8_t>> d_protos;
1810
  /* initialised with gnutls_priority_init() in the constructors, released in the destructor */
  gnutls_priority_t d_priorityCache{nullptr};
1811
  /* current tickets key, replaced by addTicketsKey() */
  SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
1812
  /* when false, addTicketsKey(), rotateTicketsKey() and loadTicketsKeys() return without doing anything */
  bool d_enableTickets{true};
1813
  /* passed to the connections created by getClientConnection() */
  bool d_validateCerts{true};
1814
};
1815

1816
#endif /* HAVE_GNUTLS */
1817

1818
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
1819

1820
bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
1821
{
45✔
1822
  if (ctx == nullptr) {
45!
1823
    return false;
×
1824
  }
×
1825
  /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */
1826
  const std::vector<std::vector<uint8_t>> dotAlpns = {{'d', 'o', 't'}};
45✔
1827
  ctx->setALPNProtos(dotAlpns);
45✔
1828
  return true;
45✔
1829
}
45✔
1830

1831
bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
1832
{
28✔
1833
  if (ctx == nullptr) {
28!
1834
    return false;
×
1835
  }
×
1836
  /* This code is only called for incoming/server TLS contexts (not outgoing/client),
1837
     and h2o sets it own ALPN values.
1838
     We want to set the ALPN for DoH:
1839
     - HTTP/1.1 so that the OpenSSL callback ALPN accepts it, letting us later return a static response
1840
     - HTTP/2
1841
  */
1842
  const std::vector<std::vector<uint8_t>> dohAlpns{{'h', '2'},{'h', 't', 't', 'p', '/', '1', '.', '1'}};
28✔
1843
  ctx->setALPNProtos(dohAlpns);
28✔
1844

1845
  return true;
28✔
1846
}
28✔
1847

1848
/* Creates a new TLS context for this frontend, sets the ALPN values matching
   d_alpn on it, then atomically stores it into d_ctx. Always returns true;
   does nothing else when built without DoT and DoH support. */
bool TLSFrontend::setupTLS()
{
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
  std::shared_ptr<TLSCtx> context;

  /* an explicitly requested provider is used if support for it was compiled in */
#if defined(HAVE_GNUTLS)
  if (d_provider == "gnutls") {
    context = std::make_shared<GnuTLSIOCtx>(*this);
  }
#endif /* HAVE_GNUTLS */
#if defined(HAVE_LIBSSL)
  if (d_provider == "openssl") {
    context = std::make_shared<OpenSSLTLSIOCtx>(*this);
  }
#endif /* HAVE_LIBSSL */

  if (context == nullptr) {
    /* nothing selected so far: get the "best" available provider,
       OpenSSL first and GnuTLS otherwise */
#if defined(HAVE_LIBSSL)
    context = std::make_shared<OpenSSLTLSIOCtx>(*this);
#elif defined(HAVE_GNUTLS)
    context = std::make_shared<GnuTLSIOCtx>(*this);
#else
#error "TLS support needed but neither libssl nor GnuTLS were selected"
#endif
  }

  if (d_alpn == ALPN::DoT) {
    setupDoTProtocolNegotiation(context);
  }
  else if (d_alpn == ALPN::DoH) {
    setupDoHProtocolNegotiation(context);
  }

  /* publish the fully set up context */
  std::atomic_store_explicit(&d_ctx, std::move(context), std::memory_order_release);
#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
  return true;
}
54✔
1885

1886
std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
1887
{
43✔
1888
#ifdef HAVE_DNS_OVER_TLS
43✔
1889
  /* get the "best" available provider */
1890
  if (!params.d_provider.empty()) {
43✔
1891
#if defined(HAVE_GNUTLS)
36✔
1892
    if (params.d_provider == "gnutls") {
36✔
1893
      return std::make_shared<GnuTLSIOCtx>(params);
18✔
1894
    }
18✔
1895
#endif /* HAVE_GNUTLS */
18✔
1896
#if defined(HAVE_LIBSSL)
18✔
1897
    if (params.d_provider == "openssl") {
18!
1898
      return std::make_shared<OpenSSLTLSIOCtx>(params);
18✔
1899
    }
18✔
1900
#endif /* HAVE_LIBSSL */
18✔
1901
  }
18✔
1902

1903
#if defined(HAVE_LIBSSL)
7✔
1904
  return std::make_shared<OpenSSLTLSIOCtx>(params);
7✔
1905
#elif defined(HAVE_GNUTLS)
1906
  return std::make_shared<GnuTLSIOCtx>(params);
1907
#else
1908
#error "DNS over TLS support needed but neither libssl nor GnuTLS were selected"
1909
#endif
1910

1911
#endif /* HAVE_DNS_OVER_TLS */
×
1912
  return nullptr;
×
1913
}
43✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc