• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 17235120617

26 Aug 2025 10:17AM UTC coverage: 65.959% (-0.02%) from 65.977%
17235120617

Pull #16016

github

web-flow
Merge d1e0ec6fc into 9eeac00a7
Pull Request #16016: auth: random doc nits

42117 of 92446 branches covered (45.56%)

Branch coverage included in aggregate %.

128034 of 165518 relevant lines covered (77.35%)

5925196.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.62
/pdns/dnsdistdist/doh.cc
1
#include "config.h"
2
#include "doh.hh"
3

4
#ifdef HAVE_DNS_OVER_HTTPS
5
#ifdef HAVE_LIBH2OEVLOOP
6
#define H2O_USE_EPOLL 1
7

8
#include <atomic>
#include <cerrno>
#include <iostream>
#include <string_view>
#include <thread>
12

13
#include <boost/algorithm/string.hpp>
14
#include <h2o.h>
15
#include <h2o/http2.h>
16

17
#include <openssl/err.h>
18
#include <openssl/ssl.h>
19

20
#include "base64.hh"
21
#include "dnsname.hh"
22
#undef CERT
23
#include "dnsdist.hh"
24
#include "dnsdist-tcp.hh"
25
#include "misc.hh"
26
#include "dns.hh"
27
#include "dolog.hh"
28
#include "dnsdist-concurrent-connections.hh"
29
#include "dnsdist-dnsparser.hh"
30
#include "dnsdist-ecs.hh"
31
#include "dnsdist-metrics.hh"
32
#include "dnsdist-proxy-protocol.hh"
33
#include "libssl.hh"
34
#include "threadname.hh"
35

36
/* So, how does this work? We use h2o for our HTTP/2 and TLS needs.
   If the operator has configured multiple IP addresses to listen on,
   we launch multiple h2o listener threads. We can, however, hook into multiple
   URLs on the same IP. When this was written there was no SNI support (DOHUnit
   now records an `sni` value, so that remark may be outdated).

   h2o is event driven, so we get a callback whenever a new DNS query arrives.
   When it does, we do some minimal parsing on it, and send it on to the
   dnsdist worker thread which we also launched.

   This dnsdist worker thread injects the query into the normal dnsdist flow
   (over a pipe). The response also goes back over a (different) pipe,
   where we pick it up and deliver it back to h2o.

   For coordination, we use the h2o socket multiplexer, which also watches our
   pipe.
*/
52

53
/* h2o notes.
   Paths, parameters, etc. just *happen* to be null-terminated in HTTP/2.
   They are not in HTTP/1, so you MUST use the length field!
*/
57

58
/* 'Intermediate' compatibility from https://wiki.mozilla.org/Security/Server_Side_TLS#Intermediate_compatibility_.28default.29 */
/* OpenSSL cipher list string used as the DoH default (presumably when the operator does not
   configure one -- confirm at the call site). "!DSS" at the end excludes DSS-authenticated suites.
   NOTE(review): this list still contains 3DES suites (ECDHE-*-DES-CBC3-SHA, EDH-RSA-DES-CBC3-SHA,
   DES-CBC3-SHA); the current Mozilla 'intermediate' profile no longer lists them, so consider
   refreshing this snapshot. */
static constexpr std::string_view DOH_DEFAULT_CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA:!DSS";
60

61
class DOHAcceptContext
62
{
63
public:
64
  DOHAcceptContext()
65
  {
25✔
66
    memset(&d_h2o_accept_ctx, 0, sizeof(d_h2o_accept_ctx));
25✔
67
    d_rotatingTicketsKey.clear();
25✔
68
  }
25✔
69
  DOHAcceptContext(const DOHAcceptContext&) = delete;
70
  DOHAcceptContext(DOHAcceptContext&&) = delete;
71
  DOHAcceptContext& operator=(const DOHAcceptContext&) = delete;
72
  DOHAcceptContext& operator=(DOHAcceptContext&&) = delete;
73

74
  h2o_accept_ctx_t* get()
75
  {
183✔
76
    return &d_h2o_accept_ctx;
183✔
77
  }
183✔
78

79
  ~DOHAcceptContext()
80
  {
2✔
81
    SSL_CTX_free(d_h2o_accept_ctx.ssl_ctx);
2✔
82
    d_h2o_accept_ctx.ssl_ctx = nullptr;
2✔
83
  }
2✔
84

85
  void decrementConcurrentConnections() const
86
  {
134✔
87
    if (d_cs != nullptr) {
134!
88
      --d_cs->tcpCurrentConnections;
134✔
89
    }
134✔
90
  }
134✔
91

92
  [[nodiscard]] time_t getNextTicketsKeyRotation() const
93
  {
×
94
    return d_ticketsKeyNextRotation;
×
95
  }
×
96

97
  [[nodiscard]] size_t getTicketsKeysCount() const
98
  {
×
99
    size_t res = 0;
×
100
    if (d_ticketKeys) {
×
101
      res = d_ticketKeys->getKeysCount();
×
102
    }
×
103
    return res;
×
104
  }
×
105

106
  void rotateTicketsKey(time_t now)
107
  {
46✔
108
    if (!d_ticketKeys) {
46✔
109
      return;
1✔
110
    }
1✔
111

112
    d_ticketKeys->rotateTicketsKey(now);
45✔
113

114
    if (d_ticketsKeyRotationDelay > 0) {
45!
115
      d_ticketsKeyNextRotation = now + d_ticketsKeyRotationDelay;
45✔
116
    }
45✔
117
  }
45✔
118

119
  void loadTicketsKeys(const std::string& keyFile)
120
  {
6✔
121
    if (!d_ticketKeys) {
6!
122
      return;
×
123
    }
×
124
    d_ticketKeys->loadTicketsKeys(keyFile);
6✔
125

126
    if (d_ticketsKeyRotationDelay > 0) {
6!
127
      d_ticketsKeyNextRotation = time(nullptr) + d_ticketsKeyRotationDelay;
6✔
128
    }
6✔
129
  }
6✔
130

131
  void handleTicketsKeyRotation()
132
  {
278✔
133
    if (d_ticketsKeyRotationDelay == 0) {
278!
134
      return;
×
135
    }
×
136

137
    time_t now = time(nullptr);
278✔
138
    if (now > d_ticketsKeyNextRotation) {
278✔
139
      if (d_rotatingTicketsKey.test_and_set()) {
24!
140
        /* someone is already rotating */
141
        return;
×
142
      }
×
143
      try {
24✔
144
        rotateTicketsKey(now);
24✔
145

146
        d_rotatingTicketsKey.clear();
24✔
147
      }
24✔
148
      catch(const std::runtime_error& e) {
24✔
149
        d_rotatingTicketsKey.clear();
×
150
        throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
×
151
      }
×
152
      catch(...) {
24✔
153
        d_rotatingTicketsKey.clear();
×
154
        throw;
×
155
      }
×
156
    }
24✔
157
  }
278✔
158

159
  std::map<int, std::string> d_ocspResponses;
160
  std::unique_ptr<OpenSSLTLSTicketKeysRing> d_ticketKeys{nullptr};
161
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
162
  pdns::UniqueFilePtr d_keyLogFile{nullptr};
163
  ClientState* d_cs{nullptr};
164
  time_t d_ticketsKeyRotationDelay{0};
165

166
private:
167
  h2o_accept_ctx_t d_h2o_accept_ctx{};
168
  time_t d_ticketsKeyNextRotation{0};
169
  std::atomic_flag d_rotatingTicketsKey;
170
};
171

172
struct DOHUnit;
173

174
// we create one of these per thread, and pass around a pointer to it
175
// through the bowels of h2o
176
struct DOHServerConfig
177
{
178
  DOHServerConfig(uint32_t idleTimeout, uint32_t internalPipeBufferSize): accept_ctx(std::make_shared<DOHAcceptContext>())
23✔
179
  {
23✔
180
#ifndef USE_SINGLE_ACCEPTOR_THREAD
23✔
181
    {
23✔
182
      auto [sender, receiver] = pdns::channel::createObjectQueue<DOHUnit>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverBlocking, internalPipeBufferSize);
23✔
183
      d_querySender = std::move(sender);
23✔
184
      d_queryReceiver = std::move(receiver);
23✔
185
    }
23✔
186
#endif /* USE_SINGLE_ACCEPTOR_THREAD */
23✔
187

188
    {
23✔
189
      auto [sender, receiver] = pdns::channel::createObjectQueue<DOHUnit>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize);
23✔
190
      d_responseSender = std::move(sender);
23✔
191
      d_responseReceiver = std::move(receiver);
23✔
192
    }
23✔
193

194
    h2o_config_init(&h2o_config);
23✔
195
    h2o_config.http2.idle_timeout = static_cast<uint64_t>(idleTimeout) * 1000;
23✔
196
    /* if you came here for a way to make the number of concurrent streams (concurrent requests per connection)
197
       configurable, or even just bigger, I have bad news for you.
198
       h2o_config.http2.max_concurrent_requests_per_connection (default of 100) is capped by
199
       H2O_HTTP2_SETTINGS_HOST.max_concurrent_streams which is not configurable. Even if decided to change the
200
       hard-coded value, libh2o's author warns that there might be parts of the code where the stream ID is stored
201
       in 8 bits, making 256 a hard value: https://github.com/h2o/h2o/issues/805
202
    */
203
  }
23✔
204
  DOHServerConfig(const DOHServerConfig&) = delete;
205
  DOHServerConfig(DOHServerConfig&&) = delete;
206
  DOHServerConfig& operator=(const DOHServerConfig&) = delete;
207
  DOHServerConfig& operator=(DOHServerConfig&&) = delete;
208
  ~DOHServerConfig() = default;
×
209

210
  std::set<std::string, std::less<>> paths;
211
  h2o_globalconf_t h2o_config{};
212
  h2o_context_t h2o_ctx{};
213
  std::unique_ptr<h2o_socket_t,decltype(&h2o_socket_close)> h2o_socket{nullptr, h2o_socket_close};
214
  std::shared_ptr<DOHAcceptContext> accept_ctx{nullptr};
215
  ClientState* clientState{nullptr};
216
  std::shared_ptr<DOHFrontend> dohFrontend{nullptr};
217
#ifndef USE_SINGLE_ACCEPTOR_THREAD
218
  pdns::channel::Sender<DOHUnit> d_querySender;
219
  pdns::channel::Receiver<DOHUnit> d_queryReceiver;
220
#endif /* USE_SINGLE_ACCEPTOR_THREAD */
221
  pdns::channel::Sender<DOHUnit> d_responseSender;
222
  pdns::channel::Receiver<DOHUnit> d_responseReceiver;
223
};
224

225
struct DOHUnit : public DOHUnitInterface
226
{
227
  DOHUnit(PacketBuffer&& query_, std::string&& path_, std::string&& host_): path(std::move(path_)), host(std::move(host_)), query(std::move(query_))
94✔
228
  {
94✔
229
    ids.ednsAdded = false;
94✔
230
  }
94✔
231
  ~DOHUnit() override
232
  {
94✔
233
    if (self != nullptr) {
94!
234
      *self = nullptr;
×
235
    }
×
236
  }
94✔
237

238
  DOHUnit(const DOHUnit&) = delete;
239
  DOHUnit(DOHUnit&&) = delete;
240
  DOHUnit& operator=(const DOHUnit&) = delete;
241
  DOHUnit& operator=(DOHUnit&&) = delete;
242

243
  InternalQueryState ids;
244
  std::string sni;
245
  std::string path;
246
  std::string scheme;
247
  std::string host;
248
  std::string contentType;
249
  PacketBuffer query;
250
  PacketBuffer response;
251
  std::unique_ptr<std::unordered_map<std::string, std::string>> headers;
252
  st_h2o_req_t* req{nullptr};
253
  DOHUnit** self{nullptr};
254
  DOHServerConfig* dsc{nullptr};
255
  pdns::channel::Sender<DOHUnit>* responseSender{nullptr};
256
  size_t query_at{0};
257
  int rsock{-1};
258
  /* the status_code is set from
259
     processDOHQuery() (which is executed in
260
     the DOH client thread) so that the correct
261
     response can be sent in on_dnsdist(),
262
     after the DOHUnit has been passed back to
263
     the main DoH thread.
264
  */
265
  uint16_t status_code{200};
266
  /* whether the query was re-sent to the backend over
267
     TCP after receiving a truncated answer over UDP */
268
  bool tcp{false};
269
  bool truncated{false};
270

271
  [[nodiscard]] std::string getHTTPPath() const override;
272
  [[nodiscard]] std::string getHTTPQueryString() const override;
273
  [[nodiscard]] const std::string& getHTTPHost() const override;
274
  [[nodiscard]] const std::string& getHTTPScheme() const override;
275
  [[nodiscard]] const std::unordered_map<std::string, std::string>& getHTTPHeaders() const override;
276
  [[nodiscard]] std::shared_ptr<TCPQuerySender> getQuerySender() const override
277
  {
×
278
    return nullptr;
×
279
  }
×
280
  void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType="") override;
281
  void handleTimeout() override;
282
  void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, [[maybe_unused]] const std::shared_ptr<DownstreamState>& downstream) override;
283
};
284
using DOHUnitUniquePtr = std::unique_ptr<DOHUnit>;
285

286
/* This internal function sends back the object to the main thread to send a reply.
287
   The caller should NOT release or touch the unit after calling this function */
288
static void sendDoHUnitToTheMainThread(DOHUnitUniquePtr&& dohUnit, const char* description)
289
{
98✔
290
  if (dohUnit->responseSender == nullptr) {
98!
291
    return;
×
292
  }
×
293
  try {
98✔
294
    if (!dohUnit->responseSender->send(std::move(dohUnit))) {
98!
295
      ++dnsdist::metrics::g_stats.dohResponsePipeFull;
×
296
      vinfolog("Unable to pass a %s to the DoH worker thread because the pipe is full", description);
×
297
    }
×
298
  } catch (const std::exception& e) {
98✔
299
    vinfolog("Unable to pass a %s to the DoH worker thread because we couldn't write to the pipe: %s", description, e.what());
×
300
  }
×
301
}
98✔
302

303
/* This function is called from other threads than the main DoH one,
304
   instructing it to send a 502 error to the client. */
305
void DOHUnit::handleTimeout()
306
{
×
307
  status_code = 502;
×
308
  sendDoHUnitToTheMainThread(std::unique_ptr<DOHUnit>(this), "DoH timeout");
×
309
}
×
310

311
struct DOHConnection
312
{
313
  std::shared_ptr<DOHAcceptContext> d_acceptCtx{nullptr};
314
  ComboAddress d_remote;
315
  ComboAddress d_local;
316
  struct timeval d_connectionStartTime{0, 0};
317
  size_t d_nbQueries{0};
318
  int d_desc{-1};
319
  uint8_t d_concurrentStreams{0};
320
};
321

322
static thread_local std::unordered_map<int, DOHConnection> t_conns;
323

324
static void on_socketclose(void *data)
325
{
134✔
326
  auto* conn = static_cast<DOHConnection*>(data);
134✔
327
  if (conn != nullptr) {
134!
328
    if (conn->d_acceptCtx) {
134!
329
      struct timeval now{};
134✔
330
      gettimeofday(&now, nullptr);
134✔
331

332
      auto diff = now - conn->d_connectionStartTime;
134✔
333

334
      conn->d_acceptCtx->decrementConcurrentConnections();
134✔
335
      conn->d_acceptCtx->d_cs->updateTCPMetrics(conn->d_nbQueries, diff.tv_sec * 1000 + diff.tv_usec / 1000, 0);
134✔
336
    }
134✔
337

338
    dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(conn->d_remote);
134✔
339
    // you can no longer touch conn, or data, after this call
340
    t_conns.erase(conn->d_desc);
134✔
341
  }
134✔
342
}
134✔
343

344
/* Returns the HTTP reason phrase for statusCode, or "Unknown" for codes not in the table.
   The returned reference points to static storage and stays valid for the program's lifetime.
   HTTP/2 has no 'reason' anyway, so this table does not need to be exhaustive. */
static const std::string& getReasonFromStatusCode(uint16_t statusCode)
{
  static const std::string fallbackReason{"Unknown"};
  static const std::unordered_map<uint16_t, std::string> reasonByCode{
    /* 2xx */
    {200, "OK"},
    /* 3xx */
    {301, "Moved Permanently"},
    {302, "Found"},
    {303, "See Other"},
    {304, "Not Modified"},
    {305, "Use Proxy"},
    {306, "Switch Proxy"},
    {307, "Temporary Redirect"},
    {308, "Permanent Redirect"},
    /* 4xx */
    {400, "Bad Request"},
    {401, "Unauthorized"},
    {402, "Payment Required"},
    {403, "Forbidden"},
    {404, "Not Found"},
    {405, "Method Not Allowed"},
    {406, "Not Acceptable"},
    {407, "Proxy Authentication Required"},
    {408, "Request Timeout"},
    {409, "Conflict"},
    {410, "Gone"},
    {411, "Length Required"},
    {412, "Precondition Failed"},
    {413, "Payload Too Large"},
    {414, "URI Too Long"},
    {415, "Unsupported Media Type"},
    {416, "Range Not Satisfiable"},
    {417, "Expectation Failed"},
    {418, "I'm a teapot"},
    {451, "Unavailable For Legal Reasons"},
    /* 5xx */
    {500, "Internal Server Error"},
    {501, "Not Implemented"},
    {502, "Bad Gateway"},
    {503, "Service Unavailable"},
    {504, "Gateway Timeout"},
    {505, "HTTP Version Not Supported"},
  };

  if (const auto match = reasonByCode.find(statusCode); match != reasonByCode.end()) {
    return match->second;
  }
  return fallbackReason;
}
12✔
392

393
/* Returns this thread's DOHConnection entry for the socket carrying `req`, or nullptr when it
   cannot be found (no socket, no descriptor, or descriptor not present in t_conns).
   Never throws: callers such as handleResponse() run from h2o callbacks and already handle nullptr. */
static DOHConnection* getConnectionFromQuery(const h2o_req_t* req)
{
  h2o_socket_t* sock = req->conn->callbacks->get_socket(req->conn);
  if (sock == nullptr) {
    /* this should not happen, but let's not crash on it */
    return nullptr;
  }

  const int descriptor = h2o_socket_get_fd(sock);
  if (descriptor == -1) {
    /* this should not happen, but let's not crash on it */
    return nullptr;
  }

  /* find() rather than at(): an unknown descriptor must not throw std::out_of_range */
  auto connIt = t_conns.find(descriptor);
  if (connIt == t_conns.end()) {
    return nullptr;
  }
  return &connIt->second;
}
301✔
403

404
/* Always called from the main DoH thread */
405
static void handleResponse(DOHFrontend& dohFrontend, st_h2o_req_t* req, uint16_t statusCode, const PacketBuffer& response, const std::unordered_map<std::string, std::string>& customResponseHeaders, const std::string& contentType, bool addContentType)
406
{
96✔
407
  constexpr int overwrite_if_exists = 1;
96✔
408
  constexpr int maybe_token = 1;
96✔
409
  for (auto const& headerPair : customResponseHeaders) {
96✔
410
    h2o_set_header_by_str(&req->pool, &req->res.headers, headerPair.first.c_str(), headerPair.first.size(), maybe_token, headerPair.second.c_str(), headerPair.second.size(), overwrite_if_exists);
65✔
411
  }
65✔
412

413
  if (statusCode == 200) {
96✔
414
    ++dohFrontend.d_validresponses;
84✔
415
    req->res.status = 200;
84✔
416

417
    if (addContentType) {
84!
418
      if (contentType.empty()) {
84✔
419
        h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, nullptr, H2O_STRLIT("application/dns-message"));
81✔
420
      }
81✔
421
      else {
3✔
422
        /* we need to duplicate the header content because h2o keeps a pointer and we will be deleted before the response has been sent */
423
        h2o_iovec_t contentTypeVect = h2o_strdup(&req->pool, contentType.c_str(), contentType.size());
3✔
424
        // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay,cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API
425
        h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, nullptr, contentTypeVect.base, contentTypeVect.len);
3✔
426
      }
3✔
427
    }
84✔
428

429
    if (dohFrontend.d_sendCacheControlHeaders && response.size() > sizeof(dnsheader)) {
84✔
430
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
431
      uint32_t minTTL = getDNSPacketMinTTL(reinterpret_cast<const char*>(response.data()), response.size());
81✔
432
      if (minTTL != std::numeric_limits<uint32_t>::max()) {
81✔
433
        std::string cacheControlValue = "max-age=" + std::to_string(minTTL);
61✔
434
        /* we need to duplicate the header content because h2o keeps a pointer and we will be deleted before the response has been sent */
435
        h2o_iovec_t ccv = h2o_strdup(&req->pool, cacheControlValue.c_str(), cacheControlValue.size());
61✔
436
        h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CACHE_CONTROL, nullptr, ccv.base, ccv.len);
61✔
437
      }
61✔
438
    }
81✔
439

440
    req->res.content_length = response.size();
84✔
441
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API
442
    h2o_send_inline(req, reinterpret_cast<const char*>(response.data()), response.size());
84✔
443
  }
84✔
444
  else if (statusCode >= 300 && statusCode < 400) {
12!
445
    /* in that case the response is actually a URL */
446
    /* we need to duplicate the URL because h2o uses it for the location header, keeping a pointer, and we will be deleted before the response has been sent */
447
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API
448
    h2o_iovec_t url = h2o_strdup(&req->pool, reinterpret_cast<const char*>(response.data()), response.size());
1✔
449
    h2o_send_redirect(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), url.base, url.len);
1✔
450
    ++dohFrontend.d_redirectresponses;
1✔
451
  }
1✔
452
  else {
11✔
453
    // we need to make sure it's null-terminated */
454
    if (!response.empty() && response.at(response.size() - 1) == 0) {
11!
455
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API
456
      h2o_send_error_generic(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), reinterpret_cast<const char*>(response.data()), H2O_SEND_ERROR_KEEP_HEADERS);
2✔
457
    }
2✔
458
    else {
9✔
459
      switch(statusCode) {
9✔
460
      case 400:
1✔
461
        h2o_send_error_400(req, getReasonFromStatusCode(statusCode).c_str(), "invalid DNS query" , 0);
1✔
462
        break;
1✔
463
      case 403:
1✔
464
        h2o_send_error_403(req, getReasonFromStatusCode(statusCode).c_str(), "DoH query not allowed", 0);
1✔
465
        break;
1✔
466
      case 502:
7✔
467
        h2o_send_error_502(req, getReasonFromStatusCode(statusCode).c_str(), "no downstream server available", 0);
7✔
468
        break;
7✔
469
      case 500:
×
470
        /* fall-through */
471
      default:
×
472
        h2o_send_error_500(req, getReasonFromStatusCode(statusCode).c_str(), "Internal Server Error", 0);
×
473
        break;
×
474
      }
9✔
475
    }
9✔
476

477
    ++dohFrontend.d_errorresponses;
11✔
478
  }
11✔
479

480
  if (auto* conn = getConnectionFromQuery(req)) {
96!
481
    --conn->d_concurrentStreams;
96✔
482
  }
96✔
483
}
96✔
484

485
/* Takes ownership of the DOHUnit stored in `ids.du`.
   Returns nullptr, leaving `ids.du` untouched, when `ids.du` is empty or does not hold a DOHUnit.
   The original implementation released the pointer before the dynamic_cast, leaking the object
   whenever the cast failed. */
static std::unique_ptr<DOHUnit> getDUFromIDS(InternalQueryState& ids)
{
  auto* dohUnit = dynamic_cast<DOHUnit*>(ids.du.get());
  if (dohUnit == nullptr) {
    return nullptr;
  }

  std::unique_ptr<DOHUnit> owned(dohUnit);
  /* ownership now belongs to `owned` */
  static_cast<void>(ids.du.release());
  return owned;
}
203✔
490

491
/* TCPQuerySender for DoH queries forwarded through the TCP code path ("cross-protocol").
   It turns the backend's answer, or an I/O error, back into a DOHUnit and hands it to the main
   DoH thread via sendDoHUnitToTheMainThread(). A single shared instance is used
   (DoHCrossProtocolQuery::s_sender). */
class DoHTCPCrossQuerySender final : public TCPQuerySender
{
public:
  DoHTCPCrossQuerySender() = default;
  DoHTCPCrossQuerySender(const DoHTCPCrossQuerySender&) = delete;
  DoHTCPCrossQuerySender(DoHTCPCrossQuerySender&&) = delete;
  DoHTCPCrossQuerySender& operator=(const DoHTCPCrossQuerySender&) = delete;
  DoHTCPCrossQuerySender& operator=(DoHTCPCrossQuerySender&&) = delete;
  ~DoHTCPCrossQuerySender() final = default;

  /* always reports itself as active */
  [[nodiscard]] bool active() const override
  {
    return true;
  }

  /* Handles a response received from the backend for a DoH query.
     Unless the response is already asynchronous, it first runs processResponse(). If the response
     is dropped, the unit is sent back with status 503; if processing turned it asynchronous,
     nothing more is done here. Otherwise it records the response metrics and sends the unit
     back to the main DoH thread. */
  void handleResponse(const struct timeval& now, TCPResponse&& response) override
  {
    (void)now;
    if (!response.d_idstate.du) {
      return;
    }

    // NOTE(review): the result of getDUFromIDS() is dereferenced without a null check;
    // it is null if d_idstate.du does not hold a DOHUnit.
    auto dohUnit = getDUFromIDS(response.d_idstate);
    if (dohUnit->responseSender == nullptr) {
      return;
    }

    dohUnit->response = std::move(response.d_buffer);
    dohUnit->ids = std::move(response.d_idstate);
    DNSResponse dr(dohUnit->ids, dohUnit->response, dohUnit->downstream);

    /* snapshot of the DNS header before response processing (presumably because rules may
       rewrite the packet); used later by handleResponseSent() */
    dnsheader cleartextDH{};
    memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH));

    if (!response.isAsync()) {
      /* the unit now owns itself through dr.ids.du (dr.ids is the unit's own ids),
         which lets processResponse() keep it, e.g. when the response becomes asynchronous */
      dr.ids.du = std::move(dohUnit);

      if (!processResponse(dynamic_cast<DOHUnit*>(dr.ids.du.get())->response, dr, false)) {
        if (dr.ids.du) {
          dohUnit = getDUFromIDS(dr.ids);
          dohUnit->status_code = 503;
          sendDoHUnitToTheMainThread(std::move(dohUnit), "Response dropped by rules");
        }
        return;
      }

      if (dr.isAsynchronous()) {
        return;
      }

      /* take the unit back out of its own ids */
      dohUnit = getDUFromIDS(dr.ids);
    }

    if (!dohUnit->ids.selfGenerated) {
      double udiff = dohUnit->ids.queryRealTime.udiff();
      vinfolog("Got answer from %s, relayed to %s (https), took %f us", dohUnit->downstream->d_config.remote.toStringWithPort(), dohUnit->ids.origRemote.toStringWithPort(), udiff);

      /* a UDP backend that was retried over TCP after a truncated answer is reported as DoTCP */
      auto backendProtocol = dohUnit->downstream->getProtocol();
      if (backendProtocol == dnsdist::Protocol::DoUDP && dohUnit->tcp) {
        backendProtocol = dnsdist::Protocol::DoTCP;
      }
      handleResponseSent(dohUnit->ids, udiff, dohUnit->ids.origRemote, dohUnit->downstream->d_config.remote, dohUnit->response.size(), cleartextDH, backendProtocol, true);
    }

    ++dnsdist::metrics::g_stats.responses;
    if (dohUnit->ids.cs != nullptr) {
      ++dohUnit->ids.cs->responses;
    }

    sendDoHUnitToTheMainThread(std::move(dohUnit), "cross-protocol response");
  }

  /* XFR responses are handled exactly like regular ones */
  void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
  {
    return handleResponse(now, std::move(response));
  }

  /* Called when talking to the backend failed: sends the unit back to the main DoH thread
     with status 502. */
  void notifyIOError(const struct timeval& now, TCPResponse&& response) override
  {
    (void)now;
    auto& query = response.d_idstate;
    if (!query.du) {
      return;
    }

    auto dohUnit = getDUFromIDS(query);
    if (dohUnit->responseSender == nullptr) {
      return;
    }

    dohUnit->ids = std::move(query);
    dohUnit->status_code = 502;
    sendDoHUnitToTheMainThread(std::move(dohUnit), "cross-protocol error response");
  }
};
586

587
class DoHCrossProtocolQuery : public CrossProtocolQuery
588
{
589
public:
590
  DoHCrossProtocolQuery(DOHUnitUniquePtr&& dohUnit, bool isResponse)
591
  {
67✔
592
    if (isResponse) {
67✔
593
      /* happens when a response becomes async */
594
      query = InternalQuery(std::move(dohUnit->response), std::move(dohUnit->ids));
28✔
595
    }
28✔
596
    else {
39✔
597
      /* we need to duplicate the query here because we might need
598
         the existing query later if we get a truncated answer */
599
      query = InternalQuery(PacketBuffer(dohUnit->query), std::move(dohUnit->ids));
39✔
600
    }
39✔
601

602
    /* it might have been moved when we moved dohUnit->ids */
603
    if (dohUnit) {
67!
604
      query.d_idstate.du = std::move(dohUnit);
67✔
605
    }
67✔
606

607
    /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload,
608
       clearing query.d_proxyProtocolPayloadAdded and dohUnit->proxyProtocolPayloadSize.
609
       Leave it for now because we know that the onky case where the payload has been
610
       added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT,
611
       and we know the TCP/DoT code can handle it. */
612
    query.d_proxyProtocolPayloadAdded = query.d_idstate.d_proxyProtocolPayloadSize > 0;
67✔
613
    downstream = query.d_idstate.du->downstream;
67✔
614
  }
67✔
615

616
  void handleInternalError()
617
  {
×
618
    auto dohUnit = getDUFromIDS(query.d_idstate);
×
619
    if (dohUnit == nullptr) {
×
620
      return;
×
621
    }
×
622
    dohUnit->status_code = 502;
×
623
    sendDoHUnitToTheMainThread(std::move(dohUnit), "DoH internal error");
×
624
  }
×
625

626
  std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
627
  {
51✔
628
    auto* unit = dynamic_cast<DOHUnit*>(query.d_idstate.du.get());
51✔
629
    if (unit != nullptr) {
51!
630
      unit->downstream = downstream;
51✔
631
    }
51✔
632
    return s_sender;
51✔
633
  }
51✔
634

635
  DNSQuestion getDQ() override
636
  {
55✔
637
    auto& ids = query.d_idstate;
55✔
638
    DNSQuestion dq(ids, query.d_buffer);
55✔
639
    return dq;
55✔
640
  }
55✔
641

642
  DNSResponse getDR() override
643
  {
24✔
644
    auto& ids = query.d_idstate;
24✔
645
    DNSResponse dr(ids, query.d_buffer, downstream);
24✔
646
    return dr;
24✔
647
   }
24✔
648

649
  DOHUnitUniquePtr releaseDU()
650
  {
×
651
    return getDUFromIDS(query.d_idstate);
×
652
  }
×
653

654
private:
655
  static std::shared_ptr<DoHTCPCrossQuerySender> s_sender;
656
};
657

658
std::shared_ptr<DoHTCPCrossQuerySender> DoHCrossProtocolQuery::s_sender = std::make_shared<DoHTCPCrossQuerySender>();
659

660
/* Builds a cross-protocol wrapper from a DoH DNSQuestion, taking ownership of the DOHUnit stored
   in dq.ids. The unit is updated with dq's state, query ID and (query or response) buffer.
   Throws std::runtime_error if dq.ids.du is empty or does not hold a DOHUnit. */
std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse)
{
  if (!dq.ids.du) {
    throw std::runtime_error("Trying to create a DoH cross protocol query without a valid DoH unit");
  }

  auto dohUnit = getDUFromIDS(dq.ids);
  if (!dohUnit) {
    /* getDUFromIDS() yields nullptr when the stored unit is not a DOHUnit: refuse instead of
       dereferencing a null pointer below */
    throw std::runtime_error("Trying to create a DoH cross protocol query from a non-DoH unit");
  }

  /* dq.ids may already be the unit's own state, in which case there is nothing to move */
  if (&dq.ids != &dohUnit->ids) {
    dohUnit->ids = std::move(dq.ids);
  }

  dohUnit->ids.origID = dq.getHeader()->id;

  /* make sure the unit holds dq's buffer, unless it already is the very same buffer */
  if (!isResponse) {
    if (dohUnit->query.data() != dq.getMutableData().data()) {
      dohUnit->query = std::move(dq.getMutableData());
    }
  }
  else {
    if (dohUnit->response.data() != dq.getMutableData().data()) {
      dohUnit->response = std::move(dq.getMutableData());
    }
  }

  return std::make_unique<DoHCrossProtocolQuery>(std::move(dohUnit), isResponse);
}
62✔
686

687
/*
688
   We are not in the main DoH thread but in the DoH 'client' thread.
689
*/
690
static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
691
{
94✔
692
  const auto handleImmediateResponse = [inMainThread](DOHUnitUniquePtr&& dohUnit, const char* reason) {
94✔
693
    if (inMainThread) {
30!
694
      handleResponse(*dohUnit->dsc->dohFrontend, dohUnit->req, dohUnit->status_code, dohUnit->response, dohUnit->dsc->dohFrontend->d_customResponseHeaders, dohUnit->contentType, true);
×
695
      /* so the unique pointer is stored in the InternalState which itself is stored in the unique pointer itself. We likely need
696
         a better design, but for now let's just reset the internal one since we know it is no longer needed. */
697
      dohUnit->ids.du.reset();
×
698
    }
×
699
    else {
30✔
700
      sendDoHUnitToTheMainThread(std::move(dohUnit), reason);
30✔
701
    }
30✔
702
  };
30✔
703

704
  auto& ids = unit->ids;
94✔
705
  uint16_t queryId = 0;
94✔
706
  ComboAddress remote;
94✔
707

708
  try {
94✔
709
    if (unit->req == nullptr) {
94!
710
      // we got closed meanwhile. XXX small race condition here
711
      // but we should be fine as long as we don't touch dohUnit->req
712
      // outside of the main DoH thread
713
      unit->status_code = 500;
×
714
      handleImmediateResponse(std::move(unit), "DoH killed in flight");
×
715
      return;
×
716
    }
×
717

718
    remote = ids.origRemote;
94✔
719
    DOHServerConfig* dsc = unit->dsc;
94✔
720
    ClientState& clientState = *dsc->clientState;
94✔
721

722
    if (unit->query.size() < sizeof(dnsheader) || unit->query.size() > std::numeric_limits<uint16_t>::max()) {
94!
723
      ++dnsdist::metrics::g_stats.nonCompliantQueries;
×
724
      ++clientState.nonCompliantQueries;
×
725
      unit->status_code = 400;
×
726
      handleImmediateResponse(std::move(unit), "DoH non-compliant query");
×
727
      return;
×
728
    }
×
729

730
    ++clientState.queries;
94✔
731
    ++dnsdist::metrics::g_stats.queries;
94✔
732
    ids.queryRealTime.start();
94✔
733

734
    {
94✔
735
      /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
736
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
737
      const dnsheader_aligned dnsHeader(unit->query.data());
94✔
738

739
      if (!checkQueryHeaders(*dnsHeader, clientState)) {
94✔
740
        unit->status_code = 400;
1✔
741
        handleImmediateResponse(std::move(unit), "DoH invalid headers");
1✔
742
        return;
1✔
743
      }
1✔
744

745
      if (dnsHeader->qdcount == 0U) {
93!
746
        dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
×
747
          header.rcode = RCode::NotImp;
×
748
          header.qr = true;
×
749
          return true;
×
750
        });
×
751
        unit->response = std::move(unit->query);
×
752

753
        handleImmediateResponse(std::move(unit), "DoH empty query");
×
754
        return;
×
755
      }
×
756

757
      queryId = ntohs(dnsHeader->id);
93✔
758
    }
93✔
759

760
    {
×
761
      // if there was no EDNS, we add it with a large buffer size
762
      // so we can use UDP to talk to the backend.
763
      dnsheader_aligned dnsHeader(unit->query.data());
93✔
764
      if (dnsHeader.get()->arcount == 0U) {
93✔
765
        if (addEDNS(unit->query, 4096, false, 4096, 0)) {
87!
766
          ids.ednsAdded = true;
87✔
767
        }
87✔
768
      }
87✔
769
    }
93✔
770

771
    auto downstream = unit->downstream;
93✔
772
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
773
    ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), static_cast<int>(sizeof(dnsheader)), false, &ids.qtype, &ids.qclass);
93✔
774
    DNSQuestion dnsQuestion(ids, unit->query);
93✔
775
    const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader().get());
93✔
776
    ids.origFlags = *flags;
93✔
777
    ids.cs = &clientState;
93✔
778
    dnsQuestion.sni = std::move(unit->sni);
93✔
779
    ids.du = std::move(unit);
93✔
780
    auto result = processQuery(dnsQuestion, downstream);
93✔
781

782
    if (result == ProcessQueryResult::Drop) {
93✔
783
      unit = getDUFromIDS(ids);
1✔
784
      unit->status_code = 403;
1✔
785
      handleImmediateResponse(std::move(unit), "DoH dropped query");
1✔
786
      return;
1✔
787
    }
1✔
788
    if (result == ProcessQueryResult::Asynchronous) {
92✔
789
      return;
34✔
790
    }
34✔
791
    if (result == ProcessQueryResult::SendAnswer) {
58✔
792
      unit = getDUFromIDS(ids);
27✔
793
      if (unit->response.empty()) {
27✔
794
        unit->response = std::move(unit->query);
23✔
795
      }
23✔
796
      if (unit->response.size() >= sizeof(dnsheader) && unit->contentType.empty()) {
27✔
797
        dnsheader_aligned dnsHeader(unit->response.data());
24✔
798
        handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *(dnsHeader.get()), dnsdist::Protocol::DoH, dnsdist::Protocol::DoH, false);
24✔
799
      }
24✔
800
      handleImmediateResponse(std::move(unit), "DoH self-answered response");
27✔
801
      return;
27✔
802
    }
27✔
803

804
    unit = getDUFromIDS(ids);
31✔
805
    if (result != ProcessQueryResult::PassToBackend) {
31!
806
      unit->status_code = 500;
×
807
      handleImmediateResponse(std::move(unit), "DoH no backend available");
×
808
      return;
×
809
    }
×
810

811
    if (downstream == nullptr) {
31!
812
      unit->status_code = 502;
×
813
      handleImmediateResponse(std::move(unit), "DoH no backend available");
×
814
      return;
×
815
    }
×
816

817
    unit->downstream = downstream;
31✔
818

819
    if (downstream->isTCPOnly()) {
31✔
820
      std::string proxyProtocolPayload;
1✔
821
      /* we need to do this _before_ creating the cross protocol query because
822
         after that the buffer will have been moved */
823
      if (downstream->d_config.useProxyProtocol) {
1!
824
        proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
×
825
      }
×
826

827
      unit->ids.origID = htons(queryId);
1✔
828
      unit->tcp = true;
1✔
829

830
      /* this moves du->ids, careful! */
831
      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(unit), false);
1✔
832
      if (!cpq) {
1!
833
        // make linters happy
834
        return;
×
835
      }
×
836
      cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
1✔
837

838
      if (downstream->passCrossProtocolQuery(std::move(cpq))) {
1!
839
        return;
1✔
840
      }
1✔
841

842
      if (inMainThread) {
×
843
        // cpq is not altered if the call fails but linters are not smart enough to notice that
844
        if (cpq) {
×
845
          // NOLINTNEXTLINE(bugprone-use-after-move): cpq is not altered if the call fails
846
          unit = cpq->releaseDU();
×
847
        }
×
848
        unit->status_code = 502;
×
849
        handleImmediateResponse(std::move(unit), "DoH internal error");
×
850
      }
×
851
      else {
×
852
        // cpq is not altered if the call fails but linters are not smart enough to notice that
853
        if (cpq) {
×
854
          // NOLINTNEXTLINE(bugprone-use-after-move): cpq is not altered if the call fails
855
          cpq->handleInternalError();
×
856
        }
×
857
      }
×
858
      return;
×
859
    }
1✔
860

861
    auto& query = unit->query;
30✔
862
    ids.du = std::move(unit);
30✔
863
    if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dnsQuestion, query)) {
30✔
864
      unit = getDUFromIDS(ids);
1✔
865
      unit->status_code = 502;
1✔
866
      handleImmediateResponse(std::move(unit), "DoH internal error");
1✔
867
      return;
1✔
868
    }
1✔
869
  }
30✔
870
  catch (const std::exception& e) {
94✔
871
    vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
×
872
    unit->status_code = 500;
×
873
    handleImmediateResponse(std::move(unit), "DoH internal error");
×
874
    return;
×
875
  }
×
876
}
94✔
877

878
/* called when a HTTP response is about to be sent, from the main DoH thread */
879
static void on_response_ready_cb(struct st_h2o_filter_t *self, h2o_req_t *req, h2o_ostream_t **slot)
880
{
109✔
881
  (void)self;
109✔
882
  if (req == nullptr) {
109!
883
    return;
×
884
  }
×
885

886
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API
887
  auto* dsc = static_cast<DOHServerConfig*>(req->conn->ctx->storage.entries[0].data);
109✔
888

889
  DOHFrontend::HTTPVersionStats* stats = nullptr;
109✔
890
  if (req->version < 0x200) {
109✔
891
    /* HTTP 1.x */
892
    stats = &dsc->dohFrontend->d_http1Stats;
5✔
893
  }
5✔
894
  else {
104✔
895
    /* HTTP 2.0 */
896
    stats = &dsc->dohFrontend->d_http2Stats;
104✔
897
  }
104✔
898

899
  switch (req->res.status) {
109✔
900
  case 200:
84✔
901
    ++stats->d_nb200Responses;
84✔
902
    break;
84✔
903
  case 400:
11✔
904
    ++stats->d_nb400Responses;
11✔
905
    break;
11✔
906
  case 403:
2✔
907
    ++stats->d_nb403Responses;
2✔
908
    break;
2✔
909
  case 500:
×
910
    ++stats->d_nb500Responses;
×
911
    break;
×
912
  case 502:
7✔
913
    ++stats->d_nb502Responses;
7✔
914
    break;
7✔
915
  default:
5✔
916
    ++stats->d_nbOtherResponses;
5✔
917
    break;
5✔
918
  }
109✔
919

920
  h2o_setup_next_ostream(req, slot);
109✔
921
}
109✔
922

923
/* this is called by h2o when our request dies.
924
   We use this to signal to the 'du' that this req is no longer alive */
925
static void on_generator_dispose(void *_self)
926
{
94✔
927
  auto* dohUnit = static_cast<DOHUnit**>(_self);
94✔
928
  if (*dohUnit != nullptr) { // if nullptr, on_dnsdist cleaned up dohUnit already
94!
929
    (*dohUnit)->self = nullptr;
×
930
    (*dohUnit)->req = nullptr;
×
931
  }
×
932
}
94✔
933

934
/* This executes in the main DoH thread.
935
   We allocate a DOHUnit and send it to dnsdistclient() function in the doh client thread
936
   via a pipe */
937
static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_req_t* req, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, std::string&& path)
938
{
96✔
939
  auto* conn = getConnectionFromQuery(req);
96✔
940

941
  try {
96✔
942
    /* we only parse it there as a sanity check, we will parse it again later */
943
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
944
    DNSPacketMangler mangler(reinterpret_cast<char*>(query.data()), query.size());
96✔
945
    mangler.skipDomainName();
96✔
946
    mangler.skipBytes(4);
96✔
947

948
    /* we are doing quite some copies here, sorry about that,
949
       but we can't keep accessing the req object once we are in a different thread
950
       because the request might get killed by h2o at pretty much any time */
951
    auto dohUnit = std::make_unique<DOHUnit>(std::move(query), std::move(path), std::string(req->authority.base, req->authority.len));
96✔
952
    dohUnit->dsc = dsc;
96✔
953
    dohUnit->req = req;
96✔
954
    dohUnit->ids.origDest = local;
96✔
955
    dohUnit->ids.origRemote = remote;
96✔
956
    dohUnit->ids.protocol = dnsdist::Protocol::DoH;
96✔
957
    dohUnit->responseSender = &dsc->d_responseSender;
96✔
958
    if (req->scheme != nullptr) {
96✔
959
      dohUnit->scheme = std::string(req->scheme->name.base, req->scheme->name.len);
94✔
960
    }
94✔
961
    dohUnit->query_at = req->query_at;
96✔
962

963
    if (dsc->dohFrontend->d_keepIncomingHeaders) {
96✔
964
      dohUnit->headers = std::make_unique<std::unordered_map<std::string, std::string>>();
22✔
965
      dohUnit->headers->reserve(req->headers.size);
22✔
966
      for (size_t i = 0; i < req->headers.size; ++i) {
88✔
967
        // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API
968
        (*dohUnit->headers)[std::string(req->headers.entries[i].name->base, req->headers.entries[i].name->len)] = std::string(req->headers.entries[i].value.base, req->headers.entries[i].value.len);
66✔
969
      }
66✔
970
    }
22✔
971

972
    if (conn != nullptr) {
96✔
973
      ++conn->d_concurrentStreams;
94✔
974
    }
94✔
975
#ifdef HAVE_H2O_SOCKET_GET_SSL_SERVER_NAME
976
    h2o_socket_t* sock = req->conn->callbacks->get_socket(req->conn);
977
    const char * sni = h2o_socket_get_ssl_server_name(sock);
978
    if (sni != nullptr) {
979
      dohUnit->sni = sni;
980
    }
981
#endif /* HAVE_H2O_SOCKET_GET_SSL_SERVER_NAME */
982
    dohUnit->self = static_cast<DOHUnit**>(h2o_mem_alloc_shared(&req->pool, sizeof(*self), on_generator_dispose));
96✔
983
    *(dohUnit->self) = dohUnit.get();
96✔
984

985
#ifdef USE_SINGLE_ACCEPTOR_THREAD
986
    processDOHQuery(std::move(dohUnit), true);
987
#else /* USE_SINGLE_ACCEPTOR_THREAD */
988
    try {
96✔
989
      if (!dsc->d_querySender.send(std::move(dohUnit))) {
96!
990
        ++dnsdist::metrics::g_stats.dohQueryPipeFull;
×
991
        vinfolog("Unable to pass a DoH query to the DoH worker thread because the pipe is full");
×
992
        if (conn != nullptr) {
×
993
          --conn->d_concurrentStreams;
×
994
        }
×
995
        h2o_send_error_500(req, "Internal Server Error", "Internal Server Error", 0);
×
996
      }
×
997
    }
96✔
998
    catch (...) {
96✔
999
      vinfolog("Unable to pass a DoH query to the DoH worker thread because we couldn't write to the pipe: %s", stringerror());
×
1000
      if (conn != nullptr) {
×
1001
        --conn->d_concurrentStreams;
×
1002
      }
×
1003
      h2o_send_error_500(req, "Internal Server Error", "Internal Server Error", 0);
×
1004
    }
×
1005
#endif /* USE_SINGLE_ACCEPTOR_THREAD */
96✔
1006
  }
96✔
1007
  catch (const std::exception& e) {
96✔
1008
    vinfolog("Had error parsing DoH DNS packet from %s: %s", remote.toStringWithPort(), e.what());
2!
1009
    if (conn != nullptr) {
2!
1010
      --conn->d_concurrentStreams;
2✔
1011
    }
2✔
1012
    h2o_send_error_400(req, "Bad Request", "The DNS query could not be parsed", 0);
2✔
1013
  }
2✔
1014
}
96✔
1015

1016
/* can only be called from the main DoH thread */
1017
static bool getHTTPHeaderValue(const h2o_req_t* req, const std::string& headerName, std::string_view& value)
1018
{
5✔
1019
  bool found = false;
5✔
1020
  /* early versions of boost::string_ref didn't have the ability to compare to string */
1021
  std::string_view headerNameView(headerName);
5✔
1022

1023
  for (size_t i = 0; i < req->headers.size; ++i) {
20✔
1024
    // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API
1025
    if (std::string_view(req->headers.entries[i].name->base, req->headers.entries[i].name->len) == headerNameView) {
15✔
1026
      // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API
1027
      value = std::string_view(req->headers.entries[i].value.base, req->headers.entries[i].value.len);
4✔
1028
      /* don't stop there, we might have more than one header with the same name, and we want the last one */
1029
      found = true;
4✔
1030
    }
4✔
1031
  }
15✔
1032

1033
  return found;
5✔
1034
}
5✔
1035

1036
/* can only be called from the main DoH thread */
1037
static std::optional<ComboAddress> processForwardedForHeader(const h2o_req_t* req, const ComboAddress& remote)
1038
{
5✔
1039
  static const std::string headerName = "x-forwarded-for";
5✔
1040
  std::string_view value;
5✔
1041

1042
  if (getHTTPHeaderValue(req, headerName, value)) {
5✔
1043
    try {
4✔
1044
      auto pos = value.rfind(',');
4✔
1045
      if (pos != std::string_view::npos) {
4✔
1046
        ++pos;
2✔
1047
        for (; pos < value.size() && value[pos] == ' '; ++pos)
4!
1048
        {
2✔
1049
        }
2✔
1050

1051
        if (pos < value.size()) {
2!
1052
          value = value.substr(pos);
2✔
1053
        }
2✔
1054
      }
2✔
1055
      return ComboAddress(std::string(value));
4✔
1056
    }
4✔
1057
    catch (const std::exception& e) {
4✔
1058
      vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.what());
×
1059
    }
×
1060
    catch (const PDNSException& e) {
4✔
1061
      vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.reason);
×
1062
    }
×
1063
  }
4✔
1064

1065
  return std::nullopt;
1✔
1066
}
5✔
1067

1068
/*
1069
  A query has been parsed by h2o, this executes in the main DoH thread.
1070
  For GET, the base64url-encoded payload is in the 'dns' parameter, which might be the first parameter, or not.
1071
  For POST, the payload is the payload.
1072
 */
1073
static int doh_handler(h2o_handler_t *self, h2o_req_t *req)
1074
{
109✔
1075
  try {
109✔
1076
    if (req->conn->ctx->storage.size == 0) {
109!
1077
      return 0; // although we might was well crash on this
×
1078
    }
×
1079
    // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API
1080
    auto* dsc = static_cast<DOHServerConfig*>(req->conn->ctx->storage.entries[0].data);
109✔
1081
    auto* connPtr = getConnectionFromQuery(req);
109✔
1082
    if (connPtr == nullptr) {
109!
1083
      return 0;
×
1084
    }
×
1085
    auto& conn = *connPtr;
109✔
1086
    if (conn.d_concurrentStreams >= dnsdist::doh::MAX_INCOMING_CONCURRENT_STREAMS) {
109!
1087
      vinfolog("Too many concurrent streams on connection from %d", conn.d_remote.toStringWithPort());
×
1088
      return 0;
×
1089
    }
×
1090

1091
    ++conn.d_nbQueries;
109✔
1092

1093
    h2o_socket_t* sock = req->conn->callbacks->get_socket(req->conn);
109✔
1094
    if (conn.d_nbQueries == 1) {
109!
1095
      if (h2o_socket_get_ssl_session_reused(sock) == 0) {
109✔
1096
        ++dsc->clientState->tlsNewSessions;
107✔
1097
      }
107✔
1098
      else {
2✔
1099
        ++dsc->clientState->tlsResumptions;
2✔
1100
      }
2✔
1101

1102
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API
1103
      h2o_socket_getsockname(sock, reinterpret_cast<struct sockaddr*>(&conn.d_local));
109✔
1104
    }
109✔
1105

1106
    auto remote = conn.d_remote;
109✔
1107
    if (dsc->dohFrontend->d_trustForwardedForHeader) {
109✔
1108
      auto newRemote = processForwardedForHeader(req, remote);
5✔
1109
      if (newRemote) {
5✔
1110
        remote = *newRemote;
4✔
1111
      }
4✔
1112
    }
5✔
1113

1114
    if (!dnsdist::configuration::getCurrentRuntimeConfiguration().d_ACL.match(remote)) {
109✔
1115
      ++dnsdist::metrics::g_stats.aclDrops;
1✔
1116
      vinfolog("Query from %s (DoH) dropped because of ACL", remote.toStringWithPort());
1!
1117
      h2o_send_error_403(req, "Forbidden", "DoH query not allowed because of ACL", 0);
1✔
1118
      return 0;
1✔
1119
    }
1✔
1120

1121
    if (const auto* tlsversion = h2o_socket_get_ssl_protocol_version(sock)) {
108✔
1122
      if (strcmp(tlsversion, "TLSv1.0") == 0) {
106!
1123
        ++dsc->clientState->tls10queries;
×
1124
      }
×
1125
      else if (strcmp(tlsversion, "TLSv1.1") == 0) {
106!
1126
        ++dsc->clientState->tls11queries;
×
1127
      }
×
1128
      else if (strcmp(tlsversion, "TLSv1.2") == 0) {
106!
1129
        ++dsc->clientState->tls12queries;
×
1130
      }
×
1131
      else if (strcmp(tlsversion, "TLSv1.3") == 0) {
106!
1132
        ++dsc->clientState->tls13queries;
106✔
1133
      }
106✔
1134
      else {
×
1135
        ++dsc->clientState->tlsUnknownqueries;
×
1136
      }
×
1137
    }
106✔
1138

1139
    if (dsc->dohFrontend->d_exactPathMatching) {
108✔
1140
      const std::string_view pathOnly(req->path_normalized.base, req->path_normalized.len);
106✔
1141
      if (dsc->paths.count(pathOnly) == 0) {
106✔
1142
        h2o_send_error_404(req, "Not Found", "there is no endpoint configured for this path", 0);
2✔
1143
        return 0;
2✔
1144
      }
2✔
1145
    }
106✔
1146

1147
    // would be nice to be able to use a std::string_view there,
1148
    // but regex (called by matches() internally) requires a null-terminated string
1149
    string path(req->path.base, req->path.len);
106✔
1150
    /* the responses map can be updated at runtime, so we need to take a copy of
1151
       the shared pointer, increasing the reference counter */
1152
    auto responsesMap = dsc->dohFrontend->d_responsesMap;
106✔
1153
    /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */
1154
    const size_t maxAdditionalSizeForEDNS = 35U;
106✔
1155
    if (responsesMap) {
106✔
1156
      for (const auto& entry : *responsesMap) {
27✔
1157
        if (entry->matches(path)) {
27✔
1158
          const auto& customHeaders = entry->getHeaders();
2✔
1159
          ++conn.d_concurrentStreams;
2✔
1160
          handleResponse(*dsc->dohFrontend, req, entry->getStatusCode(), entry->getContent(), customHeaders ? *customHeaders : dsc->dohFrontend->d_customResponseHeaders, std::string(), false);
2!
1161
          return 0;
2✔
1162
        }
2✔
1163
      }
27✔
1164
    }
27✔
1165

1166
    if (h2o_memis(req->method.base, req->method.len, H2O_STRLIT("POST")) != 0) {
104✔
1167
      ++dsc->dohFrontend->d_postqueries;
6✔
1168
      if (req->version >= 0x0200) {
6!
1169
        ++dsc->dohFrontend->d_http2Stats.d_nbQueries;
6✔
1170
      }
6✔
1171
      else {
×
1172
        ++dsc->dohFrontend->d_http1Stats.d_nbQueries;
×
1173
      }
×
1174

1175
      PacketBuffer query;
6✔
1176
      /* We reserve a few additional bytes to be able to add EDNS later */
1177
      query.reserve(req->entity.len + maxAdditionalSizeForEDNS);
6✔
1178
      query.resize(req->entity.len);
6✔
1179
      memcpy(query.data(), req->entity.base, req->entity.len);
6✔
1180
      doh_dispatch_query(dsc, self, req, std::move(query), conn.d_local, remote, std::move(path));
6✔
1181
    }
6✔
1182
    else if(req->query_at != SIZE_MAX && (req->path.len - req->query_at > 5)) {
98!
1183
      auto pos = path.find("?dns=");
92✔
1184
      if (pos == string::npos) {
92✔
1185
        pos = path.find("&dns=");
1✔
1186
      }
1✔
1187
      if (pos != string::npos) {
92✔
1188
        // need to base64url decode this
1189
        string sdns(path.substr(pos+5));
91✔
1190
        std::replace(sdns.begin(), sdns.end(), '-', '+');
91✔
1191
        std::replace(sdns.begin(), sdns.end(), '_', '/');
91✔
1192
        // re-add padding that may have been missing
1193
        switch (sdns.size() % 4) {
91✔
1194
        case 2:
28✔
1195
          sdns.append(2, '=');
28✔
1196
          break;
28✔
1197
        case 3:
39✔
1198
          sdns.append(1, '=');
39✔
1199
          break;
39✔
1200
        }
91✔
1201

1202
        PacketBuffer decoded;
91✔
1203

1204
        /* rough estimate so we hopefully don't need a new allocation later */
1205
        /* We reserve at few additional bytes to be able to add EDNS later */
1206
        const size_t estimate = ((sdns.size() * 3) / 4);
91✔
1207
        decoded.reserve(estimate + maxAdditionalSizeForEDNS);
91✔
1208
        if(B64Decode(sdns, decoded) < 0) {
91✔
1209
          h2o_send_error_400(req, "Bad Request", "Unable to decode BASE64-URL", 0);
1✔
1210
          ++dsc->dohFrontend->d_badrequests;
1✔
1211
          return 0;
1✔
1212
        }
1✔
1213

1214
        ++dsc->dohFrontend->d_getqueries;
90✔
1215
        if (req->version >= 0x0200) {
90!
1216
          ++dsc->dohFrontend->d_http2Stats.d_nbQueries;
90✔
1217
        }
90✔
1218
        else {
×
1219
          ++dsc->dohFrontend->d_http1Stats.d_nbQueries;
×
1220
        }
×
1221

1222
        doh_dispatch_query(dsc, self, req, std::move(decoded), conn.d_local, remote, std::move(path));
90✔
1223
      }
90✔
1224
      else
1✔
1225
      {
1✔
1226
        vinfolog("HTTP request without DNS parameter: %s", req->path.base);
1!
1227
        h2o_send_error_400(req, "Bad Request", "Unable to find the DNS parameter", 0);
1✔
1228
        ++dsc->dohFrontend->d_badrequests;
1✔
1229
        return 0;
1✔
1230
      }
1✔
1231
    }
92✔
1232
    else {
6✔
1233
      h2o_send_error_400(req, "Bad Request", "Unable to parse the request", 0);
6✔
1234
      ++dsc->dohFrontend->d_badrequests;
6✔
1235
    }
6✔
1236
    return 0;
102✔
1237
  }
104✔
1238
  catch (const std::exception& e) {
109✔
1239
    vinfolog("DOH Handler function failed with error: '%s'", e.what());
×
1240
    return 0;
×
1241
  }
×
1242
}
109✔
1243

1244
const std::unordered_map<std::string, std::string>& DOHUnit::getHTTPHeaders() const
1245
{
20✔
1246
  if (!headers) {
20!
1247
    static const HeadersMap empty{};
×
1248
    return empty;
×
1249
  }
×
1250
  return *headers;
20✔
1251
}
20✔
1252

1253
std::string DOHUnit::getHTTPPath() const
1254
{
33✔
1255
  if (query_at == SIZE_MAX) {
33✔
1256
    return path;
8✔
1257
  }
8✔
1258
  return {path, 0, query_at};
25✔
1259
}
33✔
1260

1261
const std::string& DOHUnit::getHTTPHost() const
1262
{
4✔
1263
  return host;
4✔
1264
}
4✔
1265

1266
const std::string& DOHUnit::getHTTPScheme() const
1267
{
4✔
1268
  return scheme;
4✔
1269
}
4✔
1270

1271
std::string DOHUnit::getHTTPQueryString() const
1272
{
4✔
1273
  if (query_at == SIZE_MAX) {
4✔
1274
    return {};
2✔
1275
  }
2✔
1276
  return path.substr(query_at);
2✔
1277
}
4✔
1278

1279
void DOHUnit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body_, const std::string& contentType_)
1280
{
4✔
1281
  status_code = statusCode;
4✔
1282
  response = std::move(body_);
4✔
1283
  if (!response.empty() && statusCode >= 400) {
4!
1284
    // we need to make sure it's null-terminated */
1285
    if (response.at(response.size() - 1) != 0) {
×
1286
      response.push_back(0);
×
1287
    }
×
1288
  }
×
1289

1290
  contentType = contentType_;
4✔
1291
}
4✔
1292

1293
#ifndef USE_SINGLE_ACCEPTOR_THREAD
1294
/* query has been parsed by h2o, which called doh_handler() in the main DoH thread.
1295
   In order not to block for long, doh_handler() called doh_dispatch_query() which allocated
1296
   a DOHUnit object and passed it to us */
1297
static void dnsdistclient(pdns::channel::Receiver<DOHUnit>&& receiver)
1298
{
23✔
1299
  setThreadName("dnsdist/doh-cli");
23✔
1300

1301
  for(;;) {
117✔
1302
    try {
117✔
1303
      auto tmp = receiver.receive();
117✔
1304
      if (!tmp) {
117!
1305
        continue;
×
1306
      }
×
1307
      auto dohUnit = std::move(*tmp);
117✔
1308
      /* we are not in the main DoH thread anymore, so there is a real risk of
1309
         a race condition where h2o kills the query while we are processing it,
1310
         so we can't touch the content of dohUnit->req until we are back into the
1311
         main DoH thread */
1312
      if (dohUnit->req == nullptr) {
117!
1313
        // it got killed in flight already
1314
        dohUnit->self = nullptr;
×
1315
        continue;
×
1316
      }
×
1317

1318
      processDOHQuery(std::move(dohUnit), false);
117✔
1319
    }
117✔
1320
    catch (const std::exception& e) {
117✔
1321
      vinfolog("Error while processing query received over DoH: %s", e.what());
×
1322
    }
×
1323
    catch (...) {
117✔
1324
      vinfolog("Unspecified error while processing query received over DoH");
×
1325
    }
×
1326
  }
117✔
1327
}
23✔
1328
#endif /* USE_SINGLE_ACCEPTOR_THREAD */
1329

1330
/* Called in the main DoH thread when h2o finds that dnsdist handed DoH units back to us
   by writing into the response channel (dsc->d_responseReceiver). Units come from:
   - handleDOHTimeout() when we did not get a response fast enough (called
     either from the health check thread (active) or from the frontend ones (reused))
   - processDOHQuery() in the DoH client thread, for queries it answers or rejects
     itself (self-answered, dropped, invalid, internal errors)
   For each unit: units whose request died are discarded; units holding a truncated
   response received over UDP are retried over TCP; all others are answered via
   handleResponse().
   */
static void on_dnsdist(h2o_socket_t *listener, const char *err)
{
  (void)err;
  /* we want to read as many responses from the pipe as possible before
     giving up. Even if we are overloaded and fighting with the DoH connections
     for the CPU, the first thing we need to do is to send responses to free slots
     anyway, otherwise queries and responses are piling up in our pipes, consuming
     memory and likely coming up too late after the client has gone away */
  auto* dsc = static_cast<DOHServerConfig*>(listener->data);
  while (true) {
    DOHUnitUniquePtr dohUnit{nullptr};
    try {
      auto tmp = dsc->d_responseReceiver.receive();
      if (!tmp) {
        // nothing left to read for now
        return;
      }
      dohUnit = std::move(*tmp);
    }
    catch (const std::exception& e) {
      warnlog("Error reading a DOH internal response: %s", e.what());
      return;
    }

    if (dohUnit->req == nullptr) { // it got killed in flight
      dohUnit->self = nullptr;
      continue;
    }

    /* the response received over UDP was truncated: restore the original query ID
       (the DNS header starts after any proxy protocol payload prepended to the query)
       and retry the query over TCP through a TCP worker thread */
    if (!dohUnit->tcp &&
        dohUnit->truncated &&
        dohUnit->query.size() > dohUnit->ids.d_proxyProtocolPayloadSize &&
        (dohUnit->query.size() - dohUnit->ids.d_proxyProtocolPayloadSize) > sizeof(dnsheader)) {
      /* restoring the original ID */
      dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize), [oldID=dohUnit->ids.origID](dnsheader& header) {
        header.id = oldID;
        return true;
      });
      dohUnit->ids.forwardedOverUDP = false;
      dohUnit->tcp = true;
      dohUnit->truncated = false;
      dohUnit->response.clear();

      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(dohUnit), false);

      if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
        continue;
      }
      /* NOTE(review): no response is sent to the client on this path; presumably the
         unit is released together with cpq and the request is left to h2o — confirm */
      vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP");
      continue;
    }

    if (dohUnit->self != nullptr) {
      // we are back in the h2o main thread now, so we don't risk
      // a race (h2o killing the query) when accessing dohUnit->req anymore
      *dohUnit->self = nullptr; // so we don't clean up again in on_generator_dispose
      dohUnit->self = nullptr;
    }

    handleResponse(*dsc->dohFrontend, dohUnit->req, dohUnit->status_code, dohUnit->response, dsc->dohFrontend->d_customResponseHeaders, dohUnit->contentType, true);
  }
}
1398

1399
/* called when a TCP connection has been accepted, the TLS session has not been established */
1400
static void on_accept(h2o_socket_t *listener, const char *err)
1401
{
137✔
1402
  auto* dsc = static_cast<DOHServerConfig*>(listener->data);
137✔
1403

1404
  if (err != nullptr) {
137!
1405
    return;
×
1406
  }
×
1407

1408
  h2o_socket_t* sock = h2o_evloop_socket_accept(listener);
137✔
1409
  if (sock == nullptr) {
137!
1410
    return;
×
1411
  }
×
1412

1413
  const int descriptor = h2o_socket_get_fd(sock);
137✔
1414
  if (descriptor == -1) {
137!
1415
    h2o_socket_close(sock);
×
1416
    return;
×
1417
  }
×
1418

1419
  ComboAddress remote;
137✔
1420
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API
1421
  if (h2o_socket_getpeername(sock, reinterpret_cast<struct sockaddr*>(&remote)) == 0) {
137!
1422
    vinfolog("Dropping DoH connection because we could not retrieve the remote host");
×
1423
    h2o_socket_close(sock);
×
1424
    return;
×
1425
  }
×
1426

1427
  if (dsc->dohFrontend->d_earlyACLDrop && !dsc->dohFrontend->d_trustForwardedForHeader && !dnsdist::configuration::getCurrentRuntimeConfiguration().d_ACL.match(remote)) {
137!
1428
    ++dnsdist::metrics::g_stats.aclDrops;
1✔
1429
    vinfolog("Dropping DoH connection from %s because of ACL", remote.toStringWithPort());
1!
1430
    h2o_socket_close(sock);
1✔
1431
    return;
1✔
1432
  }
1✔
1433

1434
  auto connectionResult = dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote, false);
136✔
1435
  if (connectionResult == dnsdist::IncomingConcurrentTCPConnectionsManager::NewConnectionResult::Denied) {
136✔
1436
    h2o_socket_close(sock);
1✔
1437
    return;
1✔
1438
  }
1✔
1439

1440
  auto concurrentConnections = ++dsc->clientState->tcpCurrentConnections;
135✔
1441
  if (dsc->clientState->d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > dsc->clientState->d_tcpConcurrentConnectionsLimit) {
135✔
1442
    --dsc->clientState->tcpCurrentConnections;
1✔
1443
    h2o_socket_close(sock);
1✔
1444
    return;
1✔
1445
  }
1✔
1446

1447
  if (concurrentConnections > dsc->clientState->tcpMaxConcurrentConnections.load()) {
134✔
1448
    dsc->clientState->tcpMaxConcurrentConnections.store(concurrentConnections);
28✔
1449
  }
28✔
1450

1451
  auto& conn = t_conns[descriptor];
134✔
1452

1453
  gettimeofday(&conn.d_connectionStartTime, nullptr);
134✔
1454
  conn.d_nbQueries = 0;
134✔
1455
  conn.d_acceptCtx = std::atomic_load_explicit(&dsc->accept_ctx, std::memory_order_acquire);
134✔
1456
  conn.d_desc = descriptor;
134✔
1457
  conn.d_remote = remote;
134✔
1458

1459
  sock->on_close.cb = on_socketclose;
134✔
1460
  sock->on_close.data = &conn;
134✔
1461
  sock->data = dsc;
134✔
1462

1463
  ++dsc->dohFrontend->d_httpconnects;
134✔
1464

1465
  h2o_accept(conn.d_acceptCtx->get(), sock);
134✔
1466
}
134✔
1467

1468
static int create_listener(std::shared_ptr<DOHServerConfig>& dsc, int descriptor)
1469
{
23✔
1470
  dsc->h2o_socket = std::unique_ptr<h2o_socket_t, decltype(&h2o_socket_close)>{h2o_evloop_socket_create(dsc->h2o_ctx.loop, descriptor, H2O_SOCKET_FLAG_DONT_READ), &h2o_socket_close};
23✔
1471
  dsc->h2o_socket->data = dsc.get();
23✔
1472
  h2o_socket_read_start(dsc->h2o_socket.get(), on_accept);
23✔
1473

1474
  return 0;
23✔
1475
}
23✔
1476

1477
#ifndef DISABLE_OCSP_STAPLING
1478
static int ocsp_stapling_callback(SSL* ssl, void* arg)
1479
{
2✔
1480
  if (ssl == nullptr || arg == nullptr) {
2!
1481
    return SSL_TLSEXT_ERR_NOACK;
×
1482
  }
×
1483
  const auto* ocspMap = static_cast<std::map<int, std::string>*>(arg);
2✔
1484
  return libssl_ocsp_stapling_callback(ssl, *ocspMap);
2✔
1485
}
2✔
1486
#endif /* DISABLE_OCSP_STAPLING */
1487

1488
#if OPENSSL_VERSION_MAJOR >= 3
1489
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays): OpenSSL API
1490
static int ticket_key_callback(SSL* sslContext, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* ivector, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx, int enc)
1491
#else
1492
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays): OpenSSL API
1493
static int ticket_key_callback(SSL *sslContext, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* ivector, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx, int enc)
1494
#endif
1495
{
254✔
1496
  auto* ctx = static_cast<DOHAcceptContext*>(libssl_get_ticket_key_callback_data(sslContext));
254✔
1497
  if (ctx == nullptr || !ctx->d_ticketKeys) {
254!
1498
    return -1;
×
1499
  }
×
1500

1501
  ctx->handleTicketsKeyRotation();
254✔
1502

1503
  auto ret = libssl_ticket_key_callback(sslContext, *ctx->d_ticketKeys, keyName, ivector, ectx, hctx, enc);
254✔
1504
  if (enc == 0) {
254✔
1505
    if (ret == 0) {
14✔
1506
      ++ctx->d_cs->tlsUnknownTicketKey;
3✔
1507
    }
3✔
1508
    else if (ret == 2) {
11✔
1509
      ++ctx->d_cs->tlsInactiveTicketKey;
4✔
1510
    }
4✔
1511
  }
14✔
1512

1513
  return ret;
254✔
1514
}
254✔
1515

1516
static void setupTLSContext(DOHAcceptContext& acceptCtx,
1517
                            TLSConfig& tlsConfig,
1518
                            TLSErrorCounters& counters)
1519
{
24✔
1520
  if (tlsConfig.d_ciphers.empty()) {
24✔
1521
    tlsConfig.d_ciphers = DOH_DEFAULT_CIPHERS.data();
22✔
1522
  }
22✔
1523

1524
  auto [ctx, warnings] = libssl_init_server_context_no_sni(tlsConfig, acceptCtx.d_ocspResponses);
24✔
1525
  for (const auto& warning : warnings) {
24✔
1526
    warnlog("%s", warning);
1✔
1527
  }
1✔
1528

1529
  if (tlsConfig.d_enableTickets && tlsConfig.d_numberOfTicketsKeys > 0) {
24!
1530
    acceptCtx.d_ticketKeys = std::make_unique<OpenSSLTLSTicketKeysRing>(tlsConfig.d_numberOfTicketsKeys);
23✔
1531
#if OPENSSL_VERSION_MAJOR >= 3
23✔
1532
    SSL_CTX_set_tlsext_ticket_key_evp_cb(ctx.get(), &ticket_key_callback);
23✔
1533
#else
1534
    SSL_CTX_set_tlsext_ticket_key_cb(ctx.get(), &ticket_key_callback);
1535
#endif
1536
    libssl_set_ticket_key_callback_data(ctx.get(), &acceptCtx);
23✔
1537
  }
23✔
1538

1539
#ifndef DISABLE_OCSP_STAPLING
24✔
1540
  if (!acceptCtx.d_ocspResponses.empty()) {
24✔
1541
    SSL_CTX_set_tlsext_status_cb(ctx.get(), &ocsp_stapling_callback);
3✔
1542
    SSL_CTX_set_tlsext_status_arg(ctx.get(), &acceptCtx.d_ocspResponses);
3✔
1543
  }
3✔
1544
#endif /* DISABLE_OCSP_STAPLING */
24✔
1545

1546
  libssl_set_error_counters_callback(*ctx.get(), &counters);
24✔
1547

1548
  if (!tlsConfig.d_keyLogFile.empty()) {
24!
1549
    acceptCtx.d_keyLogFile = libssl_set_key_log_file(ctx.get(), tlsConfig.d_keyLogFile);
×
1550
  }
×
1551

1552
  h2o_ssl_register_alpn_protocols(ctx.get(), h2o_http2_alpn_protocols);
24✔
1553

1554
  acceptCtx.d_ticketsKeyRotationDelay = tlsConfig.d_ticketsKeyRotationDelay;
24✔
1555
  if (tlsConfig.d_ticketKeyFile.empty()) {
24!
1556
    acceptCtx.handleTicketsKeyRotation();
24✔
1557
  }
24✔
1558
  else {
×
1559
    acceptCtx.loadTicketsKeys(tlsConfig.d_ticketKeyFile);
×
1560
  }
×
1561

1562
  auto* nativeCtx = acceptCtx.get();
24✔
1563
  nativeCtx->ssl_ctx = ctx.release();
24✔
1564
}
24✔
1565

1566
/* Point an accept context at the DoH server's h2o context and host list, and copy
   the ticket rotation delay from the frontend's TLS configuration. When setupTLS is
   true and the frontend is HTTPS, a fresh TLS context is also built; failures are
   rethrown with the listener address prepended. Finally, the accept context is
   linked to the frontend's ClientState. */
static void setupAcceptContext(DOHAcceptContext& ctx, DOHServerConfig& dsc, bool setupTLS)
{
  auto* h2oAcceptCtx = ctx.get();
  h2oAcceptCtx->ctx = &dsc.h2o_ctx;
  h2oAcceptCtx->hosts = dsc.h2o_config.hosts;

  // dohFrontend is published atomically by dohThread(), so read it the same way
  auto frontend = std::atomic_load_explicit(&dsc.dohFrontend, std::memory_order_acquire);
  auto& tlsFrontend = *frontend->d_tlsContext;
  ctx.d_ticketsKeyRotationDelay = tlsFrontend.d_tlsConfig.d_ticketsKeyRotationDelay;

  const bool needTLS = setupTLS && frontend->isHTTPS();
  if (needTLS) {
    try {
      setupTLSContext(ctx, tlsFrontend.d_tlsConfig, tlsFrontend.d_tlsCounters);
    }
    catch (const std::runtime_error& e) {
      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + tlsFrontend.d_addr.toStringWithPort() + "': " + e.what());
    }
  }

  ctx.d_cs = dsc.clientState;
}
25✔
1586

1587
/* Register `path` on the given h2o host. On that path:
   - an output filter calls on_response_ready_cb when the response stream is set up;
   - a handler dispatches requests to `on_req`.
   Returns the path configuration, or nullptr if h2o refused to register the path.
   In that case no filter or handler is created. */
static h2o_pathconf_t *register_handler(h2o_hostconf_t *hostconf, const char *path, int (*on_req)(h2o_handler_t *, h2o_req_t *))
{
  auto* pathconf = h2o_config_register_path(hostconf, path, 0);
  if (pathconf != nullptr) {
    if (auto* filter = h2o_create_filter(pathconf, sizeof(h2o_filter_t)); filter != nullptr) {
      filter->on_setup_ostream = on_response_ready_cb;
    }
    if (auto* handler = h2o_create_handler(pathconf, sizeof(h2o_handler_t)); handler != nullptr) {
      handler->on_req = on_req;
    }
  }
  return pathconf;
}
27✔
1605

1606
// this is the entrypoint from dnsdist.cc
1607
void dohThread(ClientState* clientState)
1608
{
23✔
1609
  try {
23✔
1610
    std::shared_ptr<DOHFrontend>& dohFrontend = clientState->dohFrontend;
23✔
1611
    auto& dsc = dohFrontend->d_dsc;
23✔
1612
    dsc->clientState = clientState;
23✔
1613
    std::atomic_store_explicit(&dsc->dohFrontend, clientState->dohFrontend, std::memory_order_release);
23✔
1614
    dsc->h2o_config.server_name = h2o_iovec_init(dohFrontend->d_serverTokens.c_str(), dohFrontend->d_serverTokens.size());
23✔
1615

1616
#ifndef USE_SINGLE_ACCEPTOR_THREAD
23✔
1617
    std::thread dnsdistThread(dnsdistclient, std::move(dsc->d_queryReceiver));
23✔
1618
    dnsdistThread.detach(); // gets us better error reporting
23✔
1619
#endif
23✔
1620

1621
    setThreadName("dnsdist/doh");
23✔
1622
    // I wonder if this registers an IP address.. I think it does
1623
    // this may mean we need to actually register a site "name" here and not the IP address
1624
    h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(dohFrontend->d_tlsContext->d_addr.toString().c_str(), dohFrontend->d_tlsContext->d_addr.toString().size()), 65535);
23✔
1625

1626
    dsc->paths = dohFrontend->d_urls;
23✔
1627
    for (const auto& url : dsc->paths) {
27✔
1628
      register_handler(hostconf, url.c_str(), doh_handler);
27✔
1629
    }
27✔
1630

1631
    h2o_context_init(&dsc->h2o_ctx, h2o_evloop_create(), &dsc->h2o_config);
23✔
1632

1633
    // in this complicated way we insert the DOHServerConfig pointer in there
1634
    h2o_vector_reserve(nullptr, &dsc->h2o_ctx.storage, 1);
23✔
1635
    // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API
1636
    dsc->h2o_ctx.storage.entries[0].data = dsc.get();
23✔
1637
    ++dsc->h2o_ctx.storage.size;
23✔
1638

1639
    auto sock = std::unique_ptr<h2o_socket_t, decltype(&h2o_socket_close)>{h2o_evloop_socket_create(dsc->h2o_ctx.loop, dsc->d_responseReceiver.getDescriptor(), H2O_SOCKET_FLAG_DONT_READ), &h2o_socket_close};
23✔
1640
    sock->data = dsc.get();
23✔
1641

1642
    // this listens to responses from dnsdist to turn into http responses
1643
    h2o_socket_read_start(sock.get(), on_dnsdist);
23✔
1644

1645
    setupAcceptContext(*dsc->accept_ctx, *dsc, false);
23✔
1646

1647
    if (create_listener(dsc, clientState->tcpFD) != 0) {
23!
1648
      throw std::runtime_error("DOH server failed to listen on " + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + ": " + stringerror(errno));
×
1649
    }
×
1650
    for (const auto& [addr, descriptor] : clientState->d_additionalAddresses) {
23!
1651
      if (create_listener(dsc, descriptor) != 0) {
×
1652
        throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + ": " + stringerror(errno));
×
1653
      }
×
1654
    }
×
1655

1656
    bool stop = false;
23✔
1657
    do {
1,110✔
1658
      int result = h2o_evloop_run(dsc->h2o_ctx.loop, INT32_MAX);
1,110✔
1659
      if (result == -1) {
1,110✔
1660
        if (errno != EINTR) {
23!
1661
          errlog("Error in the DoH event loop: %s", stringerror(errno));
×
1662
          stop = true;
×
1663
        }
×
1664
      }
23✔
1665
    }
1,110✔
1666
    while (!stop);
1,110✔
1667

1668
    h2o_evloop_destroy(dsc->h2o_ctx.loop);
23✔
1669
  }
23✔
1670
  catch (const std::exception& e) {
23✔
1671
    throw runtime_error("DOH thread failed to launch: " + std::string(e.what()));
×
1672
  }
×
1673
  catch (...) {
23✔
1674
    throw runtime_error("DOH thread failed to launch");
×
1675
  }
×
1676
}
23✔
1677

1678
/* Handle a UDP answer from a backend for this DoH unit.
   - Takes ownership of `this` and restores the query state.
   - A truncated answer (TC=1) only sets `truncated` and is forwarded without
     response processing.
   - Otherwise the answer goes through processResponse(). If the response rules reject
     it, a 503 is sent (when the unit is still owned here). If processing went
     asynchronous, nothing is sent from here.
   - In the remaining case the answer is stored, accounted for, and sent back to the
     main DoH thread. */
void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, InternalQueryState&& state, [[maybe_unused]] const std::shared_ptr<DownstreamState>& downstream_)
{
  auto unit = std::unique_ptr<DOHUnit>(this);
  unit->ids = std::move(state);

  {
    dnsheader_aligned headerView(udpResponse.data());
    if (headerView.get()->tc) {
      unit->truncated = true;
    }
  }

  if (!unit->truncated) {
    DNSResponse answer(unit->ids, udpResponse, unit->downstream);
    // copy of the header before response rules get a chance to modify the packet
    dnsheader headerCopy{};
    memcpy(&headerCopy, answer.getHeader().get(), sizeof(headerCopy));

    // while processing, the unit is owned by the query state (ids.du)
    answer.ids.du = std::move(unit);
    const bool kept = processResponse(udpResponse, answer, false);
    if (!kept) {
      if (answer.ids.du) {
        unit = getDUFromIDS(answer.ids);
        unit->status_code = 503;
        sendDoHUnitToTheMainThread(std::move(unit), "Response dropped by rules");
      }
      return;
    }

    if (answer.isAsynchronous()) {
      return;
    }

    unit = getDUFromIDS(answer.ids);
    unit->response = std::move(udpResponse);
    double elapsedUs = unit->ids.queryRealTime.udiff();
    vinfolog("Got answer from %s, relayed to %s (https), took %f us", unit->downstream->d_config.remote.toStringWithPort(), unit->ids.origRemote.toStringWithPort(), elapsedUs);

    handleResponseSent(unit->ids, elapsedUs, answer.ids.origRemote, unit->downstream->d_config.remote, unit->response.size(), headerCopy, unit->downstream->getProtocol(), true);

    ++dnsdist::metrics::g_stats.responses;
    if (unit->ids.cs != nullptr) {
      ++unit->ids.cs->responses;
    }
  }

  sendDoHUnitToTheMainThread(std::move(unit), "DoH response");
}
31✔
1723

1724
void H2ODOHFrontend::rotateTicketsKey(time_t now)
1725
{
22✔
1726
  if (d_dsc && d_dsc->accept_ctx) {
22!
1727
    d_dsc->accept_ctx->rotateTicketsKey(now);
22✔
1728
  }
22✔
1729
}
22✔
1730

1731
void H2ODOHFrontend::loadTicketsKeys(const std::string& keyFile)
1732
{
6✔
1733
  if (d_dsc && d_dsc->accept_ctx) {
6!
1734
    d_dsc->accept_ctx->loadTicketsKeys(keyFile);
6✔
1735
  }
6✔
1736
}
6✔
1737

1738
void H2ODOHFrontend::handleTicketsKeyRotation()
1739
{
×
1740
  if (d_dsc && d_dsc->accept_ctx) {
×
1741
    d_dsc->accept_ctx->handleTicketsKeyRotation();
×
1742
  }
×
1743
}
×
1744

1745
std::string H2ODOHFrontend::getNextTicketsKeyRotation() const
1746
{
×
1747
  if (d_dsc && d_dsc->accept_ctx) {
×
1748
    return std::to_string(d_dsc->accept_ctx->getNextTicketsKeyRotation());
×
1749
  }
×
1750
  return {};
×
1751
}
×
1752

1753
size_t H2ODOHFrontend::getTicketsKeysCount()
1754
{
×
1755
  size_t res = 0;
×
1756
  if (d_dsc && d_dsc->accept_ctx) {
×
1757
    res = d_dsc->accept_ctx->getTicketsKeysCount();
×
1758
  }
×
1759
  return res;
×
1760
}
×
1761

1762
void H2ODOHFrontend::reloadCertificates()
1763
{
2✔
1764
  auto newAcceptContext = std::make_shared<DOHAcceptContext>();
2✔
1765
  setupAcceptContext(*newAcceptContext, *d_dsc, true);
2✔
1766
  std::atomic_store_explicit(&d_dsc->accept_ctx, std::move(newAcceptContext), std::memory_order_release);
2✔
1767
}
2✔
1768

1769
void H2ODOHFrontend::setup()
1770
{
23✔
1771
  registerOpenSSLUser();
23✔
1772

1773
  d_dsc = std::make_shared<DOHServerConfig>(d_idleTimeout, d_internalPipeBufferSize);
23✔
1774

1775
  if  (isHTTPS()) {
23✔
1776
    try {
22✔
1777
      setupTLSContext(*d_dsc->accept_ctx,
22✔
1778
                      d_tlsContext->d_tlsConfig,
22✔
1779
                      d_tlsContext->d_tlsCounters);
22✔
1780
    }
22✔
1781
    catch (const std::runtime_error& e) {
22✔
1782
      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext->d_addr.toStringWithPort() + "': " + e.what());
×
1783
    }
×
1784
  }
22✔
1785
}
23✔
1786

1787
#endif /* HAVE_LIBH2OEVLOOP */
1788
#endif /* HAVE_DNS_OVER_HTTPS */
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc