• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PowerDNS / pdns / 19741624072

27 Nov 2025 03:45PM UTC coverage: 73.086% (+0.02%) from 73.065%
19741624072

Pull #16570

github

web-flow
Merge 08a2cdb1d into f94a3f63f
Pull Request #16570: rec: rewrite all unwrap calls in web.rs

38523 of 63408 branches covered (60.75%)

Branch coverage included in aggregate %.

128044 of 164496 relevant lines covered (77.84%)

6531485.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.93
/pdns/dnsdistdist/dnsdist-ecs.cc
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#include "dolog.hh"
23
#include "dnsdist.hh"
24
#include "dnsdist-dnsparser.hh"
25
#include "dnsdist-ecs.hh"
26
#include "dnsparser.hh"
27
#include "dnswriter.hh"
28
#include "ednsoptions.hh"
29
#include "ednssubnet.hh"
30

31
int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
32
{
9✔
33
  if (initialPacket.size() < sizeof(dnsheader)) {
9!
34
    return ENOENT;
×
35
  }
×
36

37
  const dnsheader_aligned dnsHeader(initialPacket.data());
9✔
38

39
  if (ntohs(dnsHeader->arcount) == 0) {
9!
40
    return ENOENT;
×
41
  }
×
42

43
  if (ntohs(dnsHeader->qdcount) == 0) {
9!
44
    return ENOENT;
×
45
  }
×
46

47
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
48
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
9✔
49

50
  size_t idx = 0;
9✔
51
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
9✔
52
  uint16_t ancount = ntohs(dnsHeader->ancount);
9✔
53
  uint16_t nscount = ntohs(dnsHeader->nscount);
9✔
54
  uint16_t arcount = ntohs(dnsHeader->arcount);
9✔
55
  string blob;
9✔
56
  dnsrecordheader recordHeader{};
9✔
57

58
  auto rrname = packetReader.getName();
9✔
59
  auto rrtype = packetReader.get16BitInt();
9✔
60
  auto rrclass = packetReader.get16BitInt();
9✔
61

62
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
9✔
63
  packetWriter.getHeader()->id = dnsHeader->id;
9✔
64
  packetWriter.getHeader()->qr = dnsHeader->qr;
9✔
65
  packetWriter.getHeader()->aa = dnsHeader->aa;
9✔
66
  packetWriter.getHeader()->tc = dnsHeader->tc;
9✔
67
  packetWriter.getHeader()->rd = dnsHeader->rd;
9✔
68
  packetWriter.getHeader()->ra = dnsHeader->ra;
9✔
69
  packetWriter.getHeader()->ad = dnsHeader->ad;
9✔
70
  packetWriter.getHeader()->cd = dnsHeader->cd;
9✔
71
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
9✔
72

73
  /* consume remaining qd if any */
74
  if (qdcount > 1) {
9!
75
    for (idx = 1; idx < qdcount; idx++) {
×
76
      rrname = packetReader.getName();
×
77
      rrtype = packetReader.get16BitInt();
×
78
      rrclass = packetReader.get16BitInt();
×
79
      (void)rrtype;
×
80
      (void)rrclass;
×
81
    }
×
82
  }
×
83

84
  /* copy AN and NS */
85
  for (idx = 0; idx < ancount; idx++) {
18✔
86
    rrname = packetReader.getName();
9✔
87
    packetReader.getDnsrecordheader(recordHeader);
9✔
88

89
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
9✔
90
    packetReader.xfrBlob(blob);
9✔
91
    packetWriter.xfrBlob(blob);
9✔
92
  }
9✔
93

94
  for (idx = 0; idx < nscount; idx++) {
9!
95
    rrname = packetReader.getName();
×
96
    packetReader.getDnsrecordheader(recordHeader);
×
97

98
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
×
99
    packetReader.xfrBlob(blob);
×
100
    packetWriter.xfrBlob(blob);
×
101
  }
×
102
  /* consume AR, looking for OPT */
103
  for (idx = 0; idx < arcount; idx++) {
30✔
104
    rrname = packetReader.getName();
21✔
105
    packetReader.getDnsrecordheader(recordHeader);
21✔
106

107
    if (recordHeader.d_type != QType::OPT) {
21✔
108
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
12✔
109
      packetReader.xfrBlob(blob);
12✔
110
      packetWriter.xfrBlob(blob);
12✔
111
    }
12✔
112
    else {
9✔
113

114
      packetReader.skip(recordHeader.d_clen);
9✔
115
    }
9✔
116
  }
21✔
117
  packetWriter.commit();
9✔
118

119
  return 0;
9✔
120
}
9✔
121

122
static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>& options, uint16_t optionCode, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
123
{
67✔
124
  for (auto it = options.begin(); it != options.end();) {
117✔
125
    if (it->first == optionCode) {
50✔
126
      optionAdded = false;
36✔
127

128
      if (!overrideExisting) {
36!
129
        return false;
×
130
      }
×
131

132
      it = options.erase(it);
36✔
133
    }
36✔
134
    else {
14✔
135
      ++it;
14✔
136
    }
14✔
137
  }
50✔
138

139
  options.emplace_back(optionCode, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
67✔
140
  return true;
67✔
141
}
67✔
142

143
bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
144
{
76✔
145
  if (initialPacket.size() < sizeof(dnsheader)) {
76!
146
    return false;
×
147
  }
×
148

149
  const dnsheader_aligned dnsHeader(initialPacket.data());
76✔
150

151
  if (ntohs(dnsHeader->qdcount) == 0) {
76!
152
    return false;
×
153
  }
×
154

155
  if (ntohs(dnsHeader->ancount) == 0 && ntohs(dnsHeader->nscount) == 0 && ntohs(dnsHeader->arcount) == 0) {
76!
156
    throw std::runtime_error("slowRewriteEDNSOptionInQueryWithRecords should not be called for queries that have no records");
×
157
  }
×
158

159
  optionAdded = false;
76✔
160
  ednsAdded = true;
76✔
161

162
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
163
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
76✔
164

165
  size_t idx = 0;
76✔
166
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
76✔
167
  uint16_t ancount = ntohs(dnsHeader->ancount);
76✔
168
  uint16_t nscount = ntohs(dnsHeader->nscount);
76✔
169
  uint16_t arcount = ntohs(dnsHeader->arcount);
76✔
170
  string blob;
76✔
171
  dnsrecordheader recordHeader{};
76✔
172

173
  auto rrname = packetReader.getName();
76✔
174
  auto rrtype = packetReader.get16BitInt();
76✔
175
  auto rrclass = packetReader.get16BitInt();
76✔
176

177
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
76✔
178
  packetWriter.getHeader()->id = dnsHeader->id;
76✔
179
  packetWriter.getHeader()->qr = dnsHeader->qr;
76✔
180
  packetWriter.getHeader()->aa = dnsHeader->aa;
76✔
181
  packetWriter.getHeader()->tc = dnsHeader->tc;
76✔
182
  packetWriter.getHeader()->rd = dnsHeader->rd;
76✔
183
  packetWriter.getHeader()->ra = dnsHeader->ra;
76✔
184
  packetWriter.getHeader()->ad = dnsHeader->ad;
76✔
185
  packetWriter.getHeader()->cd = dnsHeader->cd;
76✔
186
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
76✔
187

188
  /* consume remaining qd if any */
189
  if (qdcount > 1) {
76!
190
    for (idx = 1; idx < qdcount; idx++) {
×
191
      rrname = packetReader.getName();
×
192
      rrtype = packetReader.get16BitInt();
×
193
      rrclass = packetReader.get16BitInt();
×
194
      (void)rrtype;
×
195
      (void)rrclass;
×
196
    }
×
197
  }
×
198

199
  /* copy AN and NS */
200
  for (idx = 0; idx < ancount; idx++) {
117✔
201
    rrname = packetReader.getName();
41✔
202
    packetReader.getDnsrecordheader(recordHeader);
41✔
203

204
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
41✔
205
    packetReader.xfrBlob(blob);
41✔
206
    packetWriter.xfrBlob(blob);
41✔
207
  }
41✔
208

209
  for (idx = 0; idx < nscount; idx++) {
84✔
210
    rrname = packetReader.getName();
8✔
211
    packetReader.getDnsrecordheader(recordHeader);
8✔
212

213
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
8✔
214
    packetReader.xfrBlob(blob);
8✔
215
    packetWriter.xfrBlob(blob);
8✔
216
  }
8✔
217

218
  /* consume AR, looking for OPT */
219
  for (idx = 0; idx < arcount; idx++) {
177✔
220
    rrname = packetReader.getName();
101✔
221
    packetReader.getDnsrecordheader(recordHeader);
101✔
222

223
    if (recordHeader.d_type != QType::OPT) {
101✔
224
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
34✔
225
      packetReader.xfrBlob(blob);
34✔
226
      packetWriter.xfrBlob(blob);
34✔
227
    }
34✔
228
    else {
67✔
229

230
      ednsAdded = false;
67✔
231
      packetReader.xfrBlob(blob);
67✔
232

233
      std::vector<std::pair<uint16_t, std::string>> options;
67✔
234
      getEDNSOptionsFromContent(blob, options);
67✔
235

236
      /* getDnsrecordheader() has helpfully converted the TTL for us, which we do not want in that case */
237
      uint32_t ttl = htonl(recordHeader.d_ttl);
67✔
238
      EDNS0Record edns0{};
67✔
239
      static_assert(sizeof(edns0) == sizeof(ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
67✔
240
      memcpy(&edns0, &ttl, sizeof(edns0));
67✔
241

242
      /* addOrReplaceEDNSOption will set it to false if there is already an existing option */
243
      optionAdded = true;
67✔
244
      addOrReplaceEDNSOption(options, optionToReplace, optionAdded, overrideExisting, newOptionContent);
67✔
245
      packetWriter.addOpt(recordHeader.d_class, edns0.extRCode, ntohs(edns0.extFlags), options, edns0.version);
67✔
246
    }
67✔
247
  }
101✔
248

249
  if (ednsAdded) {
76✔
250
    packetWriter.addOpt(dnsdist::configuration::s_EdnsUDPPayloadSize, 0, 0, {{optionToReplace, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
9✔
251
    optionAdded = true;
9✔
252
  }
9✔
253

254
  packetWriter.commit();
76✔
255

256
  return true;
76✔
257
}
76✔
258

259
int locateEDNSOptRR(const PacketBuffer& packet, uint16_t* optStart, size_t* optLen, bool* last)
260
{
377✔
261
  if (optStart == nullptr || optLen == nullptr || last == nullptr) {
377!
262
    throw std::runtime_error("Invalid values passed to locateEDNSOptRR");
×
263
  }
×
264

265
  const dnsheader_aligned dnsHeader(packet.data());
377✔
266

267
  if (ntohs(dnsHeader->arcount) == 0) {
377✔
268
    return ENOENT;
214✔
269
  }
214✔
270

271
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
272
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
163✔
273

274
  size_t idx = 0;
163✔
275
  DNSName rrname;
163✔
276
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
163✔
277
  uint16_t ancount = ntohs(dnsHeader->ancount);
163✔
278
  uint16_t nscount = ntohs(dnsHeader->nscount);
163✔
279
  uint16_t arcount = ntohs(dnsHeader->arcount);
163✔
280
  uint16_t rrtype{};
163✔
281
  uint16_t rrclass{};
163✔
282
  dnsrecordheader recordHeader{};
163✔
283

284
  /* consume qd */
285
  for (idx = 0; idx < qdcount; idx++) {
326✔
286
    rrname = packetReader.getName();
163✔
287
    rrtype = packetReader.get16BitInt();
163✔
288
    rrclass = packetReader.get16BitInt();
163✔
289
    (void)rrtype;
163✔
290
    (void)rrclass;
163✔
291
  }
163✔
292

293
  /* consume AN and NS */
294
  for (idx = 0; idx < ancount + nscount; idx++) {
246✔
295
    rrname = packetReader.getName();
83✔
296
    packetReader.getDnsrecordheader(recordHeader);
83✔
297
    packetReader.skip(recordHeader.d_clen);
83✔
298
  }
83✔
299

300
  /* consume AR, looking for OPT */
301
  for (idx = 0; idx < arcount; idx++) {
180✔
302
    uint16_t start = packetReader.getPosition();
173✔
303
    rrname = packetReader.getName();
173✔
304
    packetReader.getDnsrecordheader(recordHeader);
173✔
305

306
    if (recordHeader.d_type == QType::OPT) {
173✔
307
      *optStart = start;
156✔
308
      *optLen = (packetReader.getPosition() - start) + recordHeader.d_clen;
156✔
309

310
      if (packet.size() < (*optStart + *optLen)) {
156✔
311
        throw std::range_error("Opt record overflow");
15✔
312
      }
15✔
313

314
      if (idx == ((size_t)arcount - 1)) {
141!
315
        *last = true;
141✔
316
      }
141✔
317
      else {
×
318
        *last = false;
×
319
      }
×
320
      return 0;
141✔
321
    }
156✔
322
    packetReader.skip(recordHeader.d_clen);
17✔
323
  }
17✔
324

325
  return ENOENT;
7✔
326
}
163✔
327

328
namespace dnsdist
329
{
330
/* extract the start of the OPT RR in a QUERY packet if any */
331
int getEDNSOptionsStart(const PacketBuffer& packet, const size_t qnameWireLength, uint16_t* optRDPosition, size_t* remaining)
332
{
1,554✔
333
  if (optRDPosition == nullptr || remaining == nullptr) {
1,554!
334
    throw std::runtime_error("Invalid values passed to getEDNSOptionsStart");
×
335
  }
×
336

337
  const dnsheader_aligned dnsHeader(packet.data());
1,554✔
338

339
  if (qnameWireLength >= packet.size()) {
1,554!
340
    return ENOENT;
×
341
  }
×
342

343
  if (ntohs(dnsHeader->qdcount) != 1 || ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->arcount) != 1 || ntohs(dnsHeader->nscount) != 0) {
1,554!
344
    return ENOENT;
998✔
345
  }
998✔
346

347
  size_t pos = sizeof(dnsheader) + qnameWireLength;
556✔
348
  pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
556✔
349

350
  if (pos >= packet.size()) {
556!
351
    return ENOENT;
×
352
  }
×
353

354
  if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= packet.size()) {
556!
355
    return ENOENT;
×
356
  }
×
357

358
  if (packet[pos] != 0) {
556✔
359
    /* not the root so not an OPT record */
360
    return ENOENT;
10✔
361
  }
10✔
362
  pos += 1;
546✔
363

364
  uint16_t qtype = packet.at(pos) * 256 + packet.at(pos + 1);
546✔
365
  pos += DNS_TYPE_SIZE;
546✔
366
  pos += DNS_CLASS_SIZE;
546✔
367

368
  if (qtype != QType::OPT || (packet.size() - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE)) {
546!
369
    return ENOENT;
24✔
370
  }
24✔
371

372
  pos += DNS_TTL_SIZE;
522✔
373
  *optRDPosition = pos;
522✔
374
  *remaining = packet.size() - pos;
522✔
375

376
  return 0;
522✔
377
}
546✔
378
}
379

380
void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
381
{
253✔
382
  Netmask sourceNetmask(source, ECSPrefixLength);
253✔
383
  EDNSSubnetOpts ecsOpts;
253✔
384
  ecsOpts.setSource(sourceNetmask);
253✔
385
  string payload = ecsOpts.makeOptString();
253✔
386
  generateEDNSOption(EDNSOptionCode::ECS, payload, res);
253✔
387
}
253✔
388

389
bool generateOptRR(const std::string& optRData, PacketBuffer& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
390
{
515✔
391
  const uint8_t name = 0;
515✔
392
  dnsrecordheader dnsHeader{};
515✔
393
  EDNS0Record edns0{};
515✔
394
  edns0.extRCode = ednsrcode;
515✔
395
  edns0.version = 0;
515✔
396
  edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
515✔
397

398
  if ((maximumSize - res.size()) < (sizeof(name) + sizeof(dnsHeader) + optRData.length())) {
515✔
399
    return false;
6✔
400
  }
6✔
401

402
  dnsHeader.d_type = htons(QType::OPT);
509✔
403
  dnsHeader.d_class = htons(udpPayloadSize);
509✔
404
  static_assert(sizeof(EDNS0Record) == sizeof(dnsHeader.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
509✔
405
  memcpy(&dnsHeader.d_ttl, &edns0, sizeof edns0);
509✔
406
  dnsHeader.d_clen = htons(static_cast<uint16_t>(optRData.length()));
509✔
407

408
  res.reserve(res.size() + sizeof(name) + sizeof(dnsHeader) + optRData.length());
509✔
409
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
410
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&name), reinterpret_cast<const uint8_t*>(&name) + sizeof(name));
509✔
411
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
412
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&dnsHeader), reinterpret_cast<const uint8_t*>(&dnsHeader) + sizeof(dnsHeader));
509✔
413
  res.insert(res.end(), optRData.begin(), optRData.end());
509✔
414

415
  return true;
509✔
416
}
515✔
417

418
/* Replaces the ECS option located at oldEcsOptionStartPosition, and spanning
   oldEcsOptionSize bytes, with newECSOption.

   When both have the same size the option is overwritten in place.
   Otherwise the RDLENGTH found at optRDLenPosition is adjusted, the data
   located behind the old option is moved over it and the new option is
   written after that data, at the end of the packet.

   Returns false when the resulting packet would be larger than maximumSize,
   true otherwise. Throws std::runtime_error when the positions, or the
   extent of the old option, do not fit inside the packet. */
static bool replaceEDNSClientSubnetOption(PacketBuffer& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption)
{
  /* the last test makes sure that the old option lies entirely inside the
     packet: without it the same-size path would copy past the end of the
     buffer, and the other path would compute a wrapped-around size for the
     trailing data */
  if (oldEcsOptionStartPosition >= packet.size() || optRDLenPosition >= packet.size() || oldEcsOptionSize > (packet.size() - oldEcsOptionStartPosition)) {
    throw std::runtime_error("Invalid values passed to replaceEDNSClientSubnetOption");
  }

  if (newECSOption.size() == oldEcsOptionSize) {
    /* same size as the existing option */
    memcpy(&packet.at(oldEcsOptionStartPosition), newECSOption.c_str(), oldEcsOptionSize);
  }
  else {
    /* different size than the existing option */
    const unsigned int newPacketLen = packet.size() + (newECSOption.length() - oldEcsOptionSize);
    const size_t beforeOptionLen = oldEcsOptionStartPosition;
    /* computed before any resize, so this is the amount of data behind the
       old option in the packet as it was received */
    const size_t dataBehindSize = packet.size() - beforeOptionLen - oldEcsOptionSize;

    /* check that it fits in the existing buffer */
    if (newPacketLen > packet.size()) {
      if (newPacketLen > maximumSize) {
        return false;
      }

      packet.resize(newPacketLen);
    }

    /* fix the size of ECS Option RDLen */
    uint16_t newRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
    newRDLen += (newECSOption.size() - oldEcsOptionSize);
    packet.at(optRDLenPosition) = newRDLen / 256;
    packet.at(optRDLenPosition + 1) = newRDLen % 256;

    if (dataBehindSize > 0) {
      memmove(&packet.at(oldEcsOptionStartPosition), &packet.at(oldEcsOptionStartPosition + oldEcsOptionSize), dataBehindSize);
    }
    memcpy(&packet.at(oldEcsOptionStartPosition + dataBehindSize), newECSOption.c_str(), newECSOption.size());
    /* only has an effect when the packet got smaller */
    packet.resize(newPacketLen);
  }

  return true;
}
27✔
458

459
/* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
460
   and false otherwise. */
461
bool parseEDNSOptions(const DNSQuestion& dnsQuestion)
462
{
124✔
463
  const auto dnsHeader = dnsQuestion.getHeader();
124✔
464
  if (dnsQuestion.ednsOptions != nullptr) {
124!
465
    return true;
×
466
  }
×
467

468
  // dnsQuestion.ednsOptions is mutable
469
  dnsQuestion.ednsOptions = std::make_unique<EDNSOptionViewMap>();
124✔
470

471
  if (ntohs(dnsHeader->arcount) == 0) {
124✔
472
    /* nothing in additional so no EDNS */
473
    return false;
22✔
474
  }
22✔
475

476
  if (ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->nscount) != 0 || ntohs(dnsHeader->arcount) > 1) {
102✔
477
    return slowParseEDNSOptions(dnsQuestion.getData(), *dnsQuestion.ednsOptions);
29✔
478
  }
29✔
479

480
  size_t remaining = 0;
73✔
481
  uint16_t optRDPosition{};
73✔
482
  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
73✔
483

484
  if (res == 0) {
73!
485
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
486
    res = getEDNSOptions(reinterpret_cast<const char*>(&dnsQuestion.getData().at(optRDPosition)), remaining, *dnsQuestion.ednsOptions);
73✔
487
    return (res == 0);
73✔
488
  }
73✔
489

490
  return false;
×
491
}
73✔
492

493
/* Appends newECSOption at the end of the packet, which is expected to end
   with an OPT record whose RDLENGTH field is located at optRDLenPosition,
   and increases that RDLENGTH accordingly.

   Returns true and sets ecsAdded on success. Returns false when the option
   does not fit in maximumSize bytes or when the new RDLENGTH would not fit
   in 16 bits; trailing data may already have been removed from the packet
   in that case. */
static bool addECSToExistingOPT(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, size_t optRDLenPosition, bool& ecsAdded)
{
  /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
  /* getEDNSOptionsStart has already checked that there is exactly one AR,
     no NS and no AN */
  uint16_t oldRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
  if (packet.size() != (optRDLenPosition + sizeof(uint16_t) + oldRDLen)) {
    /* we are supposed to be the last record, do we have some trailing data to remove? */
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
    uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
    packet.resize(realPacketLen);
  }

  /* the first test prevents the unsigned subtraction from wrapping around
     when the packet is already larger than the allowed maximum */
  if (packet.size() > maximumSize || (maximumSize - packet.size()) < newECSOption.size()) {
    return false;
  }

  /* RDLENGTH is a 16-bit field: refuse to write a truncated value */
  const size_t newRDLen = static_cast<size_t>(oldRDLen) + newECSOption.size();
  if (newRDLen > 0xFFFFU) {
    return false;
  }

  packet.at(optRDLenPosition) = static_cast<uint8_t>(newRDLen / 256);
  packet.at(optRDLenPosition + 1) = static_cast<uint8_t>(newRDLen % 256);

  packet.insert(packet.end(), newECSOption.begin(), newECSOption.end());
  ecsAdded = true;

  return true;
}
32✔
519

520
/* Appends to the packet a new OPT record whose rdata is newECSOption, then
   increments ARCOUNT in the DNS header.
   Returns false, without setting any flag, when generateOptRR() refuses to
   add the record; otherwise sets both ednsAdded and ecsAdded and returns
   true. */
static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded)
{
  const bool generated = generateOptRR(newECSOption, packet, maximumSize, dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false);
  if (!generated) {
    return false;
  }

  /* account for the record that was just appended */
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
    header.arcount = htons(static_cast<uint16_t>(ntohs(header.arcount) + 1));
    return true;
  });

  ecsAdded = true;
  ednsAdded = true;
  return true;
}
118✔
537

538
/* Makes sure that the query in `packet` carries the ECS option
   newECSOption, without letting the packet grow over maximumSize.

   - A query holding answer or authority records, or more than one
     additional record, is rewritten entirely by
     slowRewriteEDNSOptionInQueryWithRecords().
   - A query without a usable OPT record gets a new one holding the option.
   - A query whose OPT record already has an ECS option is left alone unless
     overrideExisting is set, in which case the option is replaced.
   - Otherwise the option is appended to the existing OPT record.

   Returns false when the option could not be added. Throws
   std::runtime_error when qnameWireLength is larger than the packet. */
bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, const size_t qnameWireLength, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
{
  if (qnameWireLength > packet.size()) {
    throw std::runtime_error("Invalid value passed to handleEDNSClientSubnet");
  }

  const dnsheader_aligned dnsHeader(packet.data());
  const uint16_t answerCount = ntohs(dnsHeader->ancount);
  const uint16_t authorityCount = ntohs(dnsHeader->nscount);
  const uint16_t additionalCount = ntohs(dnsHeader->arcount);

  if (answerCount != 0 || authorityCount != 0 || additionalCount > 1) {
    /* too many records for the in-place code below, rebuild the packet */
    PacketBuffer rewritten;
    rewritten.reserve(packet.size());

    if (!slowRewriteEDNSOptionInQueryWithRecords(packet, rewritten, ednsAdded, EDNSOptionCode::ECS, ecsAdded, overrideExisting, newECSOption)) {
      return false;
    }

    if (rewritten.size() > maximumSize) {
      /* the rewritten packet is dropped, so nothing was added after all */
      ednsAdded = false;
      ecsAdded = false;
      return false;
    }

    packet = std::move(rewritten);
    return true;
  }

  uint16_t optRDPosition = 0;
  size_t remaining = 0;

  if (dnsdist::getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining) != 0) {
    /* no EDNS but there might be another record in additional (TSIG?) */
    /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */
    const size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t);

    if (packet.size() > minimumPacketSize) {
      if (additionalCount == 0) {
        /* well now.. */
        packet.resize(minimumPacketSize);
      }
      else {
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
        const uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
        packet.resize(realPacketLen);
      }
    }

    return addEDNSWithECS(packet, maximumSize, newECSOption, ednsAdded, ecsAdded);
  }

  size_t ecsOptionStartPosition = 0;
  size_t ecsOptionSize = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
  const int lookup = getEDNSOption(reinterpret_cast<const char*>(&packet.at(optRDPosition)), remaining, EDNSOptionCode::ECS, &ecsOptionStartPosition, &ecsOptionSize);

  if (lookup != 0) {
    /* we have an EDNS OPT RR but no existing ECS option */
    return addECSToExistingOPT(packet, maximumSize, newECSOption, optRDPosition, ecsAdded);
  }

  /* there is already an ECS value */
  if (!overrideExisting) {
    return true;
  }

  return replaceEDNSClientSubnetOption(packet, maximumSize, optRDPosition + ecsOptionStartPosition, ecsOptionSize, optRDPosition, newECSOption);
}
73✔
605

606
/* Builds the ECS option for this query then adds it to the query itself.
   The option is built from dnsQuestion.ecs (network and bits) when it is
   set, and from the original remote address and dnsQuestion.ecsPrefixLength
   otherwise. dnsQuestion.ecsOverride tells whether an existing option may
   be replaced. */
bool handleEDNSClientSubnet(DNSQuestion& dnsQuestion, bool& ednsAdded, bool& ecsAdded)
{
  string newECSOption;

  if (dnsQuestion.ecs) {
    generateECSOption(dnsQuestion.ecs->getNetwork(), newECSOption, dnsQuestion.ecs->getBits());
  }
  else {
    generateECSOption(dnsQuestion.ids.origRemote, newECSOption, dnsQuestion.ecsPrefixLength);
  }

  return handleEDNSClientSubnet(dnsQuestion.getMutableData(), dnsQuestion.getMaximumSize(), dnsQuestion.ids.qname.wirelength(), ednsAdded, ecsAdded, dnsQuestion.ecsOverride, newECSOption);
}
156✔
613

614
static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
615
{
42✔
616
  const pdns::views::UnsignedCharView view(optionsStart, optionsLen);
42✔
617
  size_t pos = 0;
42✔
618
  while ((pos + 4) <= view.size()) {
62✔
619
    size_t optionBeginPos = pos;
54✔
620
    const uint16_t optionCode = 0x100 * view.at(pos) + view.at(pos + 1);
54✔
621
    pos += sizeof(optionCode);
54✔
622
    const uint16_t optionLen = 0x100 * view.at(pos) + view.at(pos + 1);
54✔
623
    pos += sizeof(optionLen);
54✔
624
    if ((pos + optionLen) > view.size()) {
54!
625
      return EINVAL;
×
626
    }
×
627
    if (optionCode == optionCodeToRemove) {
54✔
628
      if (pos + optionLen < view.size()) {
34✔
629
        /* move remaining options over the removed one,
630
           if any */
631
        // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
632
        memmove(optionsStart + optionBeginPos, optionsStart + pos + optionLen, optionsLen - (pos + optionLen));
16✔
633
      }
16✔
634
      *newOptionsLen = optionsLen - (sizeof(optionCode) + sizeof(optionLen) + optionLen);
34✔
635
      return 0;
34✔
636
    }
34✔
637
    pos += optionLen;
20✔
638
  }
20✔
639
  return ENOENT;
8✔
640
}
42✔
641

642
int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
643
{
30✔
644
  if (*optLen < optRecordMinimumSize) {
30!
645
    return EINVAL;
×
646
  }
×
647
  const pdns::views::UnsignedCharView view(optStart, *optLen);
30✔
648
  /* skip the root label, qtype, qclass and TTL */
649
  size_t position = 9;
30✔
650
  uint16_t rdLen = (0x100 * view.at(position) + view.at(position + 1));
30✔
651
  position += sizeof(rdLen);
30✔
652
  if (position + rdLen != view.size()) {
30!
653
    return EINVAL;
×
654
  }
×
655
  uint16_t newRdLen = 0;
30✔
656
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
657
  int res = removeEDNSOptionFromOptions(reinterpret_cast<unsigned char*>(optStart + position), rdLen, optionCodeToRemove, &newRdLen);
30✔
658
  if (res != 0) {
30✔
659
    return res;
8✔
660
  }
8✔
661
  *optLen -= (rdLen - newRdLen);
22✔
662
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
663
  auto* rdLenPtr = reinterpret_cast<unsigned char*>(optStart + 9);
22✔
664
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
665
  rdLenPtr[0] = newRdLen / 0x100;
22✔
666
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
667
  rdLenPtr[1] = newRdLen % 0x100;
22✔
668
  return 0;
22✔
669
}
30✔
670

671
bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
672
{
45✔
673
  if (optLen < optRecordMinimumSize) {
45!
674
    return false;
×
675
  }
×
676
  size_t position = optStart + 9;
45✔
677
  uint16_t rdLen = (0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1)));
45✔
678
  position += sizeof(rdLen);
45✔
679
  if (rdLen > (optLen - optRecordMinimumSize)) {
45!
680
    return false;
×
681
  }
×
682

683
  size_t rdEnd = position + rdLen;
45✔
684
  while ((position + 4) <= rdEnd) {
63✔
685
    const uint16_t optionCode = 0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1));
52✔
686
    position += sizeof(optionCode);
52✔
687
    const uint16_t optionLen = 0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1));
52✔
688
    position += sizeof(optionLen);
52✔
689

690
    if ((position + optionLen) > rdEnd) {
52!
691
      return false;
×
692
    }
×
693

694
    if (optionCode == optionCodeToFind) {
52✔
695
      if (optContentStart != nullptr) {
34✔
696
        *optContentStart = position;
32✔
697
      }
32✔
698

699
      if (optContentLen != nullptr) {
34✔
700
        *optContentLen = optionLen;
32✔
701
      }
32✔
702

703
      return true;
34✔
704
    }
34✔
705
    position += optionLen;
18✔
706
  }
18✔
707
  return false;
11✔
708
}
45✔
709

710
int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
711
{
12✔
712
  if (initialPacket.size() < sizeof(dnsheader)) {
12!
713
    return ENOENT;
×
714
  }
×
715

716
  const dnsheader_aligned dnsHeader(initialPacket.data());
12✔
717

718
  if (ntohs(dnsHeader->arcount) == 0) {
12!
719
    return ENOENT;
×
720
  }
×
721

722
  if (ntohs(dnsHeader->qdcount) == 0) {
12!
723
    return ENOENT;
×
724
  }
×
725

726
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
727
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
12✔
728

729
  size_t idx = 0;
12✔
730
  DNSName rrname;
12✔
731
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
12✔
732
  uint16_t ancount = ntohs(dnsHeader->ancount);
12✔
733
  uint16_t nscount = ntohs(dnsHeader->nscount);
12✔
734
  uint16_t arcount = ntohs(dnsHeader->arcount);
12✔
735
  uint16_t rrtype = 0;
12✔
736
  uint16_t rrclass = 0;
12✔
737
  string blob;
12✔
738
  dnsrecordheader recordHeader{};
12✔
739

740
  rrname = packetReader.getName();
12✔
741
  rrtype = packetReader.get16BitInt();
12✔
742
  rrclass = packetReader.get16BitInt();
12✔
743

744
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
12✔
745
  packetWriter.getHeader()->id = dnsHeader->id;
12✔
746
  packetWriter.getHeader()->qr = dnsHeader->qr;
12✔
747
  packetWriter.getHeader()->aa = dnsHeader->aa;
12✔
748
  packetWriter.getHeader()->tc = dnsHeader->tc;
12✔
749
  packetWriter.getHeader()->rd = dnsHeader->rd;
12✔
750
  packetWriter.getHeader()->ra = dnsHeader->ra;
12✔
751
  packetWriter.getHeader()->ad = dnsHeader->ad;
12✔
752
  packetWriter.getHeader()->cd = dnsHeader->cd;
12✔
753
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
12✔
754

755
  /* consume remaining qd if any */
756
  if (qdcount > 1) {
12!
757
    for (idx = 1; idx < qdcount; idx++) {
×
758
      rrname = packetReader.getName();
×
759
      rrtype = packetReader.get16BitInt();
×
760
      rrclass = packetReader.get16BitInt();
×
761
      (void)rrtype;
×
762
      (void)rrclass;
×
763
    }
×
764
  }
×
765

766
  /* copy AN and NS */
767
  for (idx = 0; idx < ancount; idx++) {
24✔
768
    rrname = packetReader.getName();
12✔
769
    packetReader.getDnsrecordheader(recordHeader);
12✔
770

771
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
12✔
772
    packetReader.xfrBlob(blob);
12✔
773
    packetWriter.xfrBlob(blob);
12✔
774
  }
12✔
775

776
  for (idx = 0; idx < nscount; idx++) {
12!
777
    rrname = packetReader.getName();
×
778
    packetReader.getDnsrecordheader(recordHeader);
×
779

780
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
×
781
    packetReader.xfrBlob(blob);
×
782
    packetWriter.xfrBlob(blob);
×
783
  }
×
784

785
  /* consume AR, looking for OPT */
786
  for (idx = 0; idx < arcount; idx++) {
36✔
787
    rrname = packetReader.getName();
24✔
788
    packetReader.getDnsrecordheader(recordHeader);
24✔
789

790
    if (recordHeader.d_type != QType::OPT) {
24✔
791
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
12✔
792
      packetReader.xfrBlob(blob);
12✔
793
      packetWriter.xfrBlob(blob);
12✔
794
    }
12✔
795
    else {
12✔
796
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, false);
12✔
797
      packetReader.xfrBlob(blob);
12✔
798
      uint16_t rdLen = blob.length();
12✔
799
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
800
      removeEDNSOptionFromOptions(reinterpret_cast<unsigned char*>(blob.data()), rdLen, optionCodeToSkip, &rdLen);
12✔
801
      /* xfrBlob(string, size) completely ignores size.. */
802
      if (rdLen > 0) {
12✔
803
        blob.resize((size_t)rdLen);
9✔
804
        packetWriter.xfrBlob(blob);
9✔
805
      }
9✔
806
      else {
3✔
807
        packetWriter.commit();
3✔
808
      }
3✔
809
    }
12✔
810
  }
24✔
811
  packetWriter.commit();
12✔
812

813
  return 0;
12✔
814
}
12✔
815

816
/* Asks generateOptRR() for an OPT record without any option, built from `payloadSize`,
   `ednsrcode` and `dnssecOK` and limited by `maximumSize`, for `packet`, then bumps the
   additional records count of the header by one.
   Returns false, leaving the header count untouched, when generateOptRR() fails. */
bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
{
  const bool generated = generateOptRR(std::string(), packet, maximumSize, payloadSize, ednsrcode, dnssecOK);
  if (generated) {
    dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
      /* the counter is stored in network byte order */
      const uint16_t additionalCount = ntohs(header.arcount);
      header.arcount = htons(additionalCount + 1);
      return true;
    });
  }

  return generated;
}
393✔
829

830
/*
831
  This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
832
  generating a NXD or NODATA answer with a SOA record in the additional section (or optionally the authority section for a full cacheable NXDOMAIN/NODATA).
833
*/
834
bool setNegativeAndAdditionalSOA(DNSQuestion& dnsQuestion, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, bool soaInAuthoritySection)
835
{
40✔
836
  auto& packet = dnsQuestion.getMutableData();
40✔
837
  auto dnsHeader = dnsQuestion.getHeader();
40✔
838
  if (ntohs(dnsHeader->qdcount) != 1) {
40!
839
    return false;
×
840
  }
×
841

842
  size_t queryPartSize = sizeof(dnsheader) + dnsQuestion.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
40✔
843
  if (packet.size() < queryPartSize) {
40!
844
    /* something is already wrong, don't build on flawed foundations */
845
    return false;
×
846
  }
×
847

848
  uint16_t qtype = htons(QType::SOA);
40✔
849
  uint16_t qclass = htons(QClass::IN);
40✔
850
  uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
40✔
851
  size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
40✔
852
  bool hadEDNS = false;
40✔
853
  bool dnssecOK = false;
40✔
854

855
  if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses) {
40!
856
    uint16_t payloadSize = 0;
40✔
857
    uint16_t zValue = 0;
40✔
858
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
859
    hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(packet.data()), packet.size(), &payloadSize, &zValue);
40✔
860
    if (hadEDNS) {
40✔
861
      dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;
20✔
862
    }
20✔
863
  }
40✔
864

865
  /* chop off everything after the question */
866
  packet.resize(queryPartSize);
40✔
867
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) {
40✔
868
    if (nxd) {
40✔
869
      header.rcode = RCode::NXDomain;
20✔
870
    }
20✔
871
    else {
20✔
872
      header.rcode = RCode::NoError;
20✔
873
    }
20✔
874
    header.qr = true;
40✔
875
    header.ancount = 0;
40✔
876
    header.nscount = 0;
40✔
877
    header.arcount = 0;
40✔
878
    return true;
40✔
879
  });
40✔
880

881
  rdLength = htons(rdLength);
40✔
882
  ttl = htonl(ttl);
40✔
883
  serial = htonl(serial);
40✔
884
  refresh = htonl(refresh);
40✔
885
  retry = htonl(retry);
40✔
886
  expire = htonl(expire);
40✔
887
  minimum = htonl(minimum);
40✔
888

889
  std::string soa;
40✔
890
  soa.reserve(soaSize);
40✔
891
  soa.append(zone.toDNSString());
40✔
892
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
893
  soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
40✔
894
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
895
  soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
40✔
896
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
897
  soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
40✔
898
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
899
  soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
40✔
900
  soa.append(mname.toDNSString());
40✔
901
  soa.append(rname.toDNSString());
40✔
902
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
903
  soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
40✔
904
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
905
  soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
40✔
906
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
907
  soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
40✔
908
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
909
  soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
40✔
910
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
911
  soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
40✔
912

913
  if (soa.size() != soaSize) {
40!
914
    throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
×
915
  }
×
916

917
  packet.insert(packet.end(), soa.begin(), soa.end());
40✔
918

919
  /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
920
     NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
921
     and have EDNS added to AR afterwards */
922
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) {
40✔
923
    if (soaInAuthoritySection) {
40✔
924
      header.nscount = htons(1);
20✔
925
    }
20✔
926
    else {
20✔
927
      header.arcount = htons(1);
20✔
928
    }
20✔
929
    return true;
40✔
930
  });
40✔
931

932
  if (hadEDNS) {
40✔
933
    /* now we need to add a new OPT record */
934
    return addEDNS(packet, dnsQuestion.getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, dnsQuestion.ednsRCode);
20✔
935
  }
20✔
936

937
  return true;
20✔
938
}
40✔
939

940
/* For a query that is being turned into a response: if the query carried an OPT record,
   that record and everything following it are cut from the packet and the additional
   count is reset; a fresh OPT record preserving the DNSSEC OK bit is then added when
   EDNS is enabled for self-generated responses.

   Returns true when there was no EDNS to begin with or when the rewrite succeeded,
   false when the computed OPT size does not fit in the packet or addEDNS() fails. */
bool addEDNSToQueryTurnedResponse(DNSQuestion& dnsQuestion)
{
  uint16_t optRDPosition{};
  /* remaining is at least the size of the rdlen + the options if any + the following records if any */
  size_t remaining = 0;

  auto& packet = dnsQuestion.getMutableData();
  if (dnsdist::getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining) != 0) {
    /* if the initial query did not have EDNS0, we are done */
    return true;
  }

  /* bytes of the OPT record located before the Z field, and before RDLENGTH */
  const size_t sizeBeforeZ = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
  const size_t sizeBeforeRDLen = sizeBeforeZ + /* Z */ 2;

  const size_t existingOptLen = sizeBeforeRDLen + remaining;
  if (existingOptLen >= packet.size()) {
    /* something is wrong, bail out */
    return false;
  }

  /* read the DO bit out of the Z field before dropping the record */
  const size_t optPosition = optRDPosition - sizeBeforeRDLen;
  const size_t zPosition = optPosition + sizeBeforeZ;
  const uint16_t zValue = 0x100 * packet.at(zPosition) + packet.at(zPosition + 1);
  const bool dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;

  /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
  packet.resize(packet.size() - existingOptLen);
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
    header.arcount = 0;
    return true;
  });

  if (!dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses) {
    /* no EDNS in self-generated responses, we are just fine */
    return true;
  }

  /* now we need to add a new OPT record */
  return addEDNS(packet, dnsQuestion.getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, dnsQuestion.ednsRCode);
}
69✔
981

982
namespace dnsdist
983
{
984
static std::optional<size_t> getEDNSRecordPosition(const DNSQuestion& dnsQuestion)
985
{
535✔
986
  try {
535✔
987
    const auto& packet = dnsQuestion.getData();
535✔
988
    if (packet.size() <= sizeof(dnsheader)) {
535!
989
      return std::nullopt;
×
990
    }
×
991

992
    uint16_t optRDPosition = 0;
535✔
993
    size_t remaining = 0;
535✔
994
    auto res = getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
535✔
995
    if (res != 0) {
535✔
996
      return std::nullopt;
334✔
997
    }
334✔
998

999
    if (optRDPosition < DNS_TTL_SIZE) {
201!
1000
      return std::nullopt;
×
1001
    }
×
1002

1003
    return optRDPosition - DNS_TTL_SIZE;
201✔
1004
  }
201✔
1005
  catch (...) {
535✔
1006
    return std::nullopt;
×
1007
  }
×
1008
}
535✔
1009

1010
// goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
1011
int getEDNSZ(const DNSQuestion& dnsQuestion)
1012
{
459✔
1013
  try {
459✔
1014
    auto position = getEDNSRecordPosition(dnsQuestion);
459✔
1015

1016
    if (!position) {
459✔
1017
      return 0;
294✔
1018
    }
294✔
1019

1020
    const auto& packet = dnsQuestion.getData();
165✔
1021
    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) {
165!
1022
      return 0;
×
1023
    }
×
1024

1025
    return 0x100 * packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) + packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1);
165✔
1026
  }
165✔
1027
  catch (...) {
459✔
1028
    return 0;
×
1029
  }
×
1030
}
459✔
1031

1032
std::optional<uint8_t> getEDNSVersion(const DNSQuestion& dnsQuestion)
1033
{
38✔
1034
  try {
38✔
1035
    auto position = getEDNSRecordPosition(dnsQuestion);
38✔
1036

1037
    if (!position) {
38✔
1038
      return std::nullopt;
20✔
1039
    }
20✔
1040

1041
    const auto& packet = dnsQuestion.getData();
18✔
1042
    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) >= packet.size()) {
18!
1043
      return std::nullopt;
×
1044
    }
×
1045

1046
    return packet.at(*position + EDNS_EXTENDED_RCODE_SIZE);
18✔
1047
  }
18✔
1048
  catch (...) {
38✔
1049
    return std::nullopt;
×
1050
  }
×
1051
}
38✔
1052

1053
std::optional<uint8_t> getEDNSExtendedRCode(const DNSQuestion& dnsQuestion)
1054
{
38✔
1055
  try {
38✔
1056
    auto position = getEDNSRecordPosition(dnsQuestion);
38✔
1057

1058
    if (!position) {
38✔
1059
      return std::nullopt;
20✔
1060
    }
20✔
1061

1062
    const auto& packet = dnsQuestion.getData();
18✔
1063
    if ((*position + EDNS_EXTENDED_RCODE_SIZE) >= packet.size()) {
18!
1064
      return std::nullopt;
×
1065
    }
×
1066

1067
    return packet.at(*position);
18✔
1068
  }
18✔
1069
  catch (...) {
38✔
1070
    return std::nullopt;
×
1071
  }
×
1072
}
38✔
1073

1074
}
1075

1076
bool queryHasEDNS(const DNSQuestion& dnsQuestion)
1077
{
415✔
1078
  uint16_t optRDPosition = 0;
415✔
1079
  size_t ecsRemaining = 0;
415✔
1080

1081
  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &ecsRemaining);
415✔
1082
  return res == 0;
415✔
1083
}
415✔
1084

1085
bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0)
1086
{
45✔
1087
  uint16_t optStart = 0;
45✔
1088
  size_t optLen = 0;
45✔
1089
  bool last = false;
45✔
1090
  int res = locateEDNSOptRR(packet, &optStart, &optLen, &last);
45✔
1091
  if (res != 0) {
45✔
1092
    // no EDNS OPT RR
1093
    return false;
34✔
1094
  }
34✔
1095

1096
  if (optLen < optRecordMinimumSize) {
11!
1097
    return false;
×
1098
  }
×
1099

1100
  if (optStart < packet.size() && packet.at(optStart) != 0) {
11!
1101
    // OPT RR Name != '.'
1102
    return false;
×
1103
  }
×
1104

1105
  static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
11✔
1106
  // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
1107
  memcpy(&edns0, &packet.at(optStart + 5), sizeof edns0);
11✔
1108
  return true;
11✔
1109
}
11✔
1110

1111
/* Sets the EDNS option `ednsCode`, with content `ednsData`, in the packet of `dnsQuestion`.

   - If the packet already has additional records, it is rewritten through
     slowRewriteEDNSOptionInQueryWithRecords(); the rewritten packet replaces the
     existing one only if it fits in the maximum size.
   - Otherwise a new OPT record carrying the option is generated and the additional
     count is set to one.

   When `isQuery` is set and an OPT record had to be added, `ids.ednsAdded` is raised so
   that any EDNS sent by the backend is removed before the response reaches the client.

   Returns true if the option has been set, false if the packet could not be rewritten,
   the rewritten packet would be too large, or the OPT record could not be generated. */
bool setEDNSOption(DNSQuestion& dnsQuestion, uint16_t ednsCode, const std::string& ednsData, bool isQuery)
{
  std::string optRData;
  generateEDNSOption(ednsCode, ednsData, optRData);

  if (dnsQuestion.getHeader()->arcount != 0) {
    bool ednsAdded = false;
    bool optionAdded = false;
    PacketBuffer newContent;
    newContent.reserve(dnsQuestion.getData().size());

    if (!slowRewriteEDNSOptionInQueryWithRecords(dnsQuestion.getData(), newContent, ednsAdded, ednsCode, optionAdded, true, optRData)) {
      return false;
    }

    if (newContent.size() > dnsQuestion.getMaximumSize()) {
      return false;
    }

    dnsQuestion.getMutableData() = std::move(newContent);
    if (isQuery && !dnsQuestion.ids.ednsAdded && ednsAdded) {
      dnsQuestion.ids.ednsAdded = true;
    }

    return true;
  }

  auto& data = dnsQuestion.getMutableData();
  if (!generateOptRR(optRData, data, dnsQuestion.getMaximumSize(), dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) {
    /* the OPT record could not be added: the option has not been set,
       so do not pretend otherwise (this mirrors what addEDNS() does) */
    return false;
  }

  dnsdist::PacketMangling::editDNSHeaderFromPacket(data, [](dnsheader& header) {
    /* there was no additional record before, the OPT one is the only one */
    header.arcount = htons(1);
    return true;
  });

  if (isQuery) {
    // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
    dnsQuestion.ids.ednsAdded = true;
  }

  return true;
}
15✔
1153

1154
namespace dnsdist
1155
{
1156
/* Turns `buffer` into a response carrying `rcode`: QR is set, AA and AD are cleared and
   RA mirrors RD. When `clearAnswers` is set, the record counts other than the question
   count are zeroed, the buffer is cut right after the question and, if the original
   packet had an OPT record, a fresh one preserving the DNSSEC OK bit is added.

   Returns false when the buffer is too short to hold the header plus the question for
   `state.qname`, or when the OPT record cannot be added; true otherwise. */
bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uint8_t rcode, bool clearAnswers)
{
  /* header + qname + qtype + qclass */
  const size_t questionEnd = sizeof(dnsheader) + state.qname.wirelength() + sizeof(uint16_t) + sizeof(uint16_t);
  if (buffer.size() < questionEnd) {
    return false;
  }

  /* the EDNS0 record has to be extracted before the buffer gets truncated */
  EDNS0Record edns0{};
  const bool hadEDNS = clearAnswers && getEDNS0Record(buffer, edns0);

  dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode, clearAnswers](dnsheader& header) {
    header.qr = true;
    header.rcode = rcode;
    header.aa = false;
    header.ad = false;
    header.ra = header.rd;

    if (clearAnswers) {
      header.ancount = 0;
      header.nscount = 0;
      header.arcount = 0;
    }
    return true;
  });

  if (!clearAnswers) {
    return true;
  }

  buffer.resize(questionEnd);
  if (!hadEDNS) {
    return true;
  }

  DNSQuestion dnsQuestion(state, buffer);
  /* extFlags was copied straight from the wire, hence the htons() on the flag */
  const bool dnssecOK = (edns0.extFlags & htons(EDNS_HEADER_FLAG_DO)) != 0;
  return addEDNS(buffer, dnsQuestion.getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, 0);
}
36✔
1196
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc