lightningnetwork / lnd / 17470957362

04 Sep 2025 04:50PM UTC coverage: 66.654% (-0.02%) from 66.67%
Build 17470957362 (push via github, web-flow)

Merge pull request #10199 from ellemouton/fixBuild

graph/db: fix type name

0 of 1 new or added line in 1 file covered. (0.0%)
400 existing lines in 24 files now uncovered.
136139 of 204247 relevant lines covered (66.65%)
21439.41 hits per line

Source File: /graph/db/sql_migration.go
package graphdb

import (
	"bytes"
	"cmp"
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net"
	"slices"
	"time"

	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/lightningnetwork/lnd/graph/db/models"
	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/sqldb"
	"github.com/lightningnetwork/lnd/sqldb/sqlc"
	"golang.org/x/time/rate"
)

// MigrateGraphToSQL migrates the graph store from a KV backend to a SQL
// backend.
//
// NOTE: this is currently not called from any code path. It is called via tests
// only for now and will be called from the main lnd binary once the
// migration is fully implemented and tested.
func MigrateGraphToSQL(ctx context.Context, cfg *SQLStoreConfig,
	kvBackend kvdb.Backend, sqlDB SQLQueries) error {

	log.Infof("Starting migration of the graph store from KV to SQL")
	t0 := time.Now()

	// Check if there is a graph to migrate.
	graphExists, err := checkGraphExists(kvBackend)
	if err != nil {
		return fmt.Errorf("failed to check graph existence: %w", err)
	}
	if !graphExists {
		log.Infof("No graph found in KV store, skipping the migration")
		return nil
	}

	// 1) Migrate all the nodes.
	err = migrateNodes(ctx, cfg.QueryCfg, kvBackend, sqlDB)
	if err != nil {
		return fmt.Errorf("could not migrate nodes: %w", err)
	}

	// 2) Migrate the source node.
	if err := migrateSourceNode(ctx, kvBackend, sqlDB); err != nil {
		return fmt.Errorf("could not migrate source node: %w", err)
	}

	// 3) Migrate all the channels and channel policies.
	err = migrateChannelsAndPolicies(ctx, cfg, kvBackend, sqlDB)
	if err != nil {
		return fmt.Errorf("could not migrate channels and policies: %w",
			err)
	}

	// 4) Migrate the Prune log.
	err = migratePruneLog(ctx, cfg.QueryCfg, kvBackend, sqlDB)
	if err != nil {
		return fmt.Errorf("could not migrate prune log: %w", err)
	}

	// 5) Migrate the closed SCID index.
	err = migrateClosedSCIDIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB)
	if err != nil {
		return fmt.Errorf("could not migrate closed SCID index: %w",
			err)
	}

	// 6) Migrate the zombie index.
	err = migrateZombieIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB)
	if err != nil {
		return fmt.Errorf("could not migrate zombie index: %w", err)
	}

	log.Infof("Finished migration of the graph store from KV to SQL in %v",
		time.Since(t0))

	return nil
}

// checkGraphExists checks if the graph exists in the KV backend.
func checkGraphExists(db kvdb.Backend) (bool, error) {
	// Check if there is even a graph to migrate.
	err := db.View(func(tx kvdb.RTx) error {
		// Check for the existence of the node bucket which is a top
		// level bucket that would have been created on the initial
		// creation of the graph store.
		nodes := tx.ReadBucket(nodeBucket)
		if nodes == nil {
			return ErrGraphNotFound
		}

		return nil
	}, func() {})
	if errors.Is(err, ErrGraphNotFound) {
		return false, nil
	} else if err != nil {
		return false, err
	}

	return true, nil
}

// migrateNodes migrates all nodes from the KV backend to the SQL database.
// It collects nodes in batches, inserts them individually, and then validates
// them in batches.
func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
	kvBackend kvdb.Backend, sqlDB SQLQueries) error {

	// Keep track of the number of nodes migrated and the number of
	// nodes skipped due to errors.
	var (
		totalTime = time.Now()

		count   uint64
		skipped uint64

		t0    = time.Now()
		chunk uint64
		s     = rate.Sometimes{
			Interval: 10 * time.Second,
		}
	)

	// batch is a map that holds node objects that have been migrated to
	// the native SQL store but have yet to be validated. The objects held
	// by this map were derived from the KVDB store and so when they are
	// validated, the map index (the SQL store node ID) will be used to
	// fetch the corresponding node object in the SQL store, and it will
	// then be compared against the original KVDB node object.
	batch := make(
		map[int64]*models.Node, cfg.MaxBatchSize,
	)

	// validateBatch validates that the batch of nodes in the 'batch' map
	// has been migrated successfully.
	validateBatch := func() error {
		if len(batch) == 0 {
			return nil
		}

		// Extract DB node IDs.
		dbIDs := make([]int64, 0, len(batch))
		for dbID := range batch {
			dbIDs = append(dbIDs, dbID)
		}

		// Batch fetch all nodes from the database.
		dbNodes, err := sqlDB.GetNodesByIDs(ctx, dbIDs)
		if err != nil {
			return fmt.Errorf("could not batch fetch nodes: %w",
				err)
		}

		// Make sure that the number of nodes fetched matches the number
		// of nodes in the batch.
		if len(dbNodes) != len(batch) {
			return fmt.Errorf("expected to fetch %d nodes, "+
				"but got %d", len(batch), len(dbNodes))
		}

		// Now, batch fetch the normalised data for all the nodes in
		// the batch.
		batchData, err := batchLoadNodeData(ctx, cfg, sqlDB, dbIDs)
		if err != nil {
			return fmt.Errorf("unable to batch load node data: %w",
				err)
		}

		for _, dbNode := range dbNodes {
			// Get the KVDB node info from the batch map.
			node, ok := batch[dbNode.ID]
			if !ok {
				return fmt.Errorf("node with ID %d not found "+
					"in batch", dbNode.ID)
			}

			// Build the migrated node from the DB node and the
			// batch node data.
			migNode, err := buildNodeWithBatchData(
				dbNode, batchData,
			)
			if err != nil {
				return fmt.Errorf("could not build migrated "+
					"node from dbNode(db id: %d, node "+
					"pub: %x): %w", dbNode.ID,
					node.PubKeyBytes, err)
			}

			// Make sure that the node addresses are sorted before
			// comparing them to ensure that the order of addresses
			// does not affect the comparison.
			slices.SortFunc(
				node.Addresses, func(i, j net.Addr) int {
					return cmp.Compare(
						i.String(), j.String(),
					)
				},
			)
			slices.SortFunc(
				migNode.Addresses, func(i, j net.Addr) int {
					return cmp.Compare(
						i.String(), j.String(),
					)
				},
			)

			err = sqldb.CompareRecords(
				node, migNode,
				fmt.Sprintf("node %x", node.PubKeyBytes),
			)
			if err != nil {
				return fmt.Errorf("node mismatch after "+
					"migration for node %x: %w",
					node.PubKeyBytes, err)
			}
		}

		// Clear the batch map for the next iteration.
		batch = make(
			map[int64]*models.Node, cfg.MaxBatchSize,
		)

		return nil
	}

	// Loop through each node in the KV store and insert it into the SQL
	// database.
	err := forEachNode(kvBackend, func(_ kvdb.RTx,
		node *models.Node) error {

		pub := node.PubKeyBytes

		// Sanity check to ensure that the node has valid extra opaque
		// data. If it does not, we'll skip it. We need to do this
		// because previously we would just persist any TLV bytes that
		// we received without validating them. Now, however, we
		// normalise the storage of extra opaque data, so we need to
		// ensure that the data is valid. We don't want to abort the
		// migration if we encounter a node with invalid extra opaque
		// data, so we'll just skip it and log a warning.
		_, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
		if errors.Is(err, ErrParsingExtraTLVBytes) {
			skipped++
			log.Warnf("Skipping migration of node %x with invalid "+
				"extra opaque data: %v", pub,
				node.ExtraOpaqueData)

			return nil
		} else if err != nil {
			return fmt.Errorf("unable to marshal extra "+
				"opaque data for node %x: %w", pub, err)
		}

		if err = maybeOverrideNodeAddresses(node); err != nil {
			skipped++
			log.Warnf("Skipping migration of node %x with invalid "+
				"address (%v): %v", pub, node.Addresses, err)

			return nil
		}

		count++
		chunk++

		// Write the node to the SQL database.
		id, err := insertNodeSQLMig(ctx, sqlDB, node)
		if err != nil {
			return fmt.Errorf("could not persist node(%x): %w", pub,
				err)
		}

		// Add to validation batch.
		batch[id] = node

		// Validate batch when full.
		if len(batch) >= int(cfg.MaxBatchSize) {
			err := validateBatch()
			if err != nil {
				return fmt.Errorf("batch validation failed: %w",
					err)
			}
		}

		s.Do(func() {
			elapsed := time.Since(t0).Seconds()
			ratePerSec := float64(chunk) / elapsed
			log.Debugf("Migrated %d nodes (%.2f nodes/sec)",
				count, ratePerSec)

			t0 = time.Now()
			chunk = 0
		})

		return nil
	}, func() {
		count = 0
		chunk = 0
		skipped = 0
		t0 = time.Now()
		batch = make(map[int64]*models.Node, cfg.MaxBatchSize)
	})
	if err != nil {
		return fmt.Errorf("could not migrate nodes: %w", err)
	}

	// Validate any remaining nodes in the batch.
	if len(batch) > 0 {
		err := validateBatch()
		if err != nil {
			return fmt.Errorf("final batch validation failed: %w",
				err)
		}
	}

	log.Infof("Migrated %d nodes from KV to SQL in %v (skipped %d nodes "+
		"due to invalid TLV streams or invalid addresses)", count,
		time.Since(totalTime), skipped)

	return nil
}

// maybeOverrideNodeAddresses checks if the node has any opaque addresses that
// can be parsed. If so, it replaces the node's addresses with the parsed
// addresses. If the address is unparseable, it returns an error.
func maybeOverrideNodeAddresses(node *models.Node) error {
	// In the majority of cases, the number of node addresses will remain
	// unchanged, so we pre-allocate a slice of the same length.
	addrs := make([]net.Addr, 0, len(node.Addresses))

	// Iterate over each address in search of any opaque addresses that we
	// can inspect.
	for _, addr := range node.Addresses {
		opaque, ok := addr.(*lnwire.OpaqueAddrs)
		if !ok {
			// Any non-opaque address is left unchanged.
			addrs = append(addrs, addr)
			continue
		}

		// For each opaque address, we'll now attempt to parse out any
		// known addresses. We'll do this in a loop, as it's possible
		// that there are several addresses encoded in a single opaque
		// address.
		payload := opaque.Payload
		for len(payload) > 0 {
			var (
				r            = bytes.NewReader(payload)
				numAddrBytes = uint16(len(payload))
			)
			byteRead, readAddr, err := lnwire.ReadAddress(
				r, numAddrBytes,
			)
			if err != nil {
				return err
			}

			// If we were able to read an address, we'll add it to
			// our list of addresses.
			if readAddr != nil {
				addrs = append(addrs, readAddr)
			}

			// If the address we read was an opaque address, it
			// means we've hit an unknown address type, and it has
			// consumed the rest of the payload. We can break out
			// of the loop.
			if _, ok := readAddr.(*lnwire.OpaqueAddrs); ok {
				break
			}

			// If we've read all the bytes, we can also break.
			if byteRead >= numAddrBytes {
				break
			}

			// Otherwise, we'll advance our payload slice and
			// continue.
			payload = payload[byteRead:]
		}
	}

	// Override the node addresses if we have any.
	if len(addrs) != 0 {
		node.Addresses = addrs
	}

	return nil
}

// migrateSourceNode migrates the source node from the KV backend to the
// SQL database.
func migrateSourceNode(ctx context.Context, kvdb kvdb.Backend,
	sqlDB SQLQueries) error {

	log.Debugf("Migrating source node from KV to SQL")

	sourceNode, err := sourceNode(kvdb)
	if errors.Is(err, ErrSourceNodeNotSet) {
		// If the source node has not been set yet, we can skip this
		// migration step.
		return nil
	} else if err != nil {
		return fmt.Errorf("could not get source node from kv "+
			"store: %w", err)
	}

	pub := sourceNode.PubKeyBytes

	// Get the DB ID of the source node by its public key. This node must
	// already exist in the SQL database, as it should have been migrated
	// in the previous node-migration step.
	id, err := sqlDB.GetNodeIDByPubKey(
		ctx, sqlc.GetNodeIDByPubKeyParams{
			PubKey:  pub[:],
			Version: int16(ProtocolV1),
		},
	)
	if err != nil {
		return fmt.Errorf("could not get source node ID: %w", err)
	}

	// Now we can add the source node to the SQL database.
	err = sqlDB.AddSourceNode(ctx, id)
	if err != nil {
		return fmt.Errorf("could not add source node to SQL store: %w",
			err)
	}

	// Verify that the source node was added correctly by fetching it back
	// from the SQL database and checking that the expected DB ID and
	// pub key are returned. We don't need to do a whole node comparison
	// here, as this was already done in the previous migration step.
	srcNodes, err := sqlDB.GetSourceNodesByVersion(ctx, int16(ProtocolV1))
	if err != nil {
		return fmt.Errorf("could not get source nodes from SQL "+
			"store: %w", err)
	}

	// The SQL store has support for multiple source nodes (for future
	// protocol versions) but this migration is purely aimed at the V1
	// store, and so we expect exactly one source node to be present.
	if len(srcNodes) != 1 {
		return fmt.Errorf("expected exactly one source node, "+
			"got %d", len(srcNodes))
	}

	// Check that the source node ID and pub key match the original
	// source node.
	if srcNodes[0].NodeID != id {
		return fmt.Errorf("source node ID mismatch after migration: "+
			"expected %d, got %d", id, srcNodes[0].NodeID)
	}
	err = sqldb.CompareRecords(pub[:], srcNodes[0].PubKey, "source node")
	if err != nil {
		return fmt.Errorf("source node pubkey mismatch after "+
			"migration: %w", err)
	}

	log.Infof("Migrated source node with pubkey %x to SQL", pub[:])

	return nil
}

// migChanInfo holds the information about a channel and its policies.
type migChanInfo struct {
	// edge is the channel object as read from the KVDB source.
	edge *models.ChannelEdgeInfo

	// policy1 is the first channel policy for the channel as read from
	// the KVDB source.
	policy1 *models.ChannelEdgePolicy

	// policy2 is the second channel policy for the channel as read
	// from the KVDB source.
	policy2 *models.ChannelEdgePolicy

	// dbInfo holds location info (in the form of DB IDs) of the channel
	// and its policies in the native-SQL destination.
	dbInfo *dbChanInfo
}

// migrateChannelsAndPolicies migrates all channels and their policies
// from the KV backend to the SQL database.
func migrateChannelsAndPolicies(ctx context.Context, cfg *SQLStoreConfig,
	kvBackend kvdb.Backend, sqlDB SQLQueries) error {

	var (
		totalTime = time.Now()

		channelCount       uint64
		skippedChanCount   uint64
		policyCount        uint64
		skippedPolicyCount uint64

		t0    = time.Now()
		chunk uint64
		s     = rate.Sometimes{
			Interval: 10 * time.Second,
		}
	)
	migChanPolicy := func(dbChanInfo *dbChanInfo,
		policy *models.ChannelEdgePolicy) error {

		// If the policy is nil, we can skip it.
		if policy == nil {
			return nil
		}

		// Unlike the special case of invalid TLV bytes for node and
		// channel announcements, we don't need to handle the case for
		// channel policies here because it is already handled in the
		// `forEachChannel` function. If the policy has invalid TLV
		// bytes, then `nil` will be passed to this function.

		policyCount++

		err := insertChanEdgePolicyMig(ctx, sqlDB, dbChanInfo, policy)
		if err != nil {
			return fmt.Errorf("could not migrate channel "+
				"policy %d: %w", policy.ChannelID, err)
		}

		return nil
	}

	// batch is used to collect migrated channel info that we will
	// batch-validate. Each entry is indexed by the DB ID of the channel
	// in the SQL database.
	batch := make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)

	// Iterate over each channel in the KV store and migrate it and its
	// policies to the SQL database.
	err := forEachChannel(kvBackend, func(channel *models.ChannelEdgeInfo,
		policy1 *models.ChannelEdgePolicy,
		policy2 *models.ChannelEdgePolicy) error {

		scid := channel.ChannelID

		// Here, we do a sanity check to ensure that the chain hash of
		// the channel returned by the KV store matches the expected
		// chain hash. This is important since in the SQL store, we will
		// no longer explicitly store the chain hash in the channel
		// info, but rather rely on the chain hash LND is running with.
		// So this is our way of ensuring that LND is running on the
		// correct network at migration time.
		if channel.ChainHash != cfg.ChainHash {
			return fmt.Errorf("channel %d has chain hash %s, "+
				"expected %s", scid, channel.ChainHash,
				cfg.ChainHash)
		}

		// Sanity check to ensure that the channel has valid extra
		// opaque data. If it does not, we'll skip it. We need to do
		// this because previously we would just persist any TLV bytes
		// that we received without validating them. Now, however, we
		// normalise the storage of extra opaque data, so we need to
		// ensure that the data is valid. We don't want to abort the
		// migration if we encounter a channel with invalid extra opaque
		// data, so we'll just skip it and log a warning.
		_, err := marshalExtraOpaqueData(channel.ExtraOpaqueData)
		if errors.Is(err, ErrParsingExtraTLVBytes) {
			log.Warnf("Skipping channel %d with invalid "+
				"extra opaque data: %v", scid,
				channel.ExtraOpaqueData)

			skippedChanCount++

			// If we skip a channel, we also skip its policies.
			if policy1 != nil {
				skippedPolicyCount++
			}
			if policy2 != nil {
				skippedPolicyCount++
			}

			return nil
		} else if err != nil {
			return fmt.Errorf("unable to marshal extra opaque "+
				"data for channel %d (%v): %w", scid,
				channel.ExtraOpaqueData, err)
		}

		channelCount++
		chunk++

		// Migrate the channel info along with its policies.
		dbChanInfo, err := insertChannelMig(ctx, sqlDB, channel)
		if err != nil {
			return fmt.Errorf("could not insert record for "+
				"channel %d in SQL store: %w", scid, err)
		}

		// Now, migrate the two channel policies for the channel.
		err = migChanPolicy(dbChanInfo, policy1)
		if err != nil {
			return fmt.Errorf("could not migrate policy1(%d): %w",
				scid, err)
		}
		err = migChanPolicy(dbChanInfo, policy2)
		if err != nil {
			return fmt.Errorf("could not migrate policy2(%d): %w",
				scid, err)
		}

		// Collect the migrated channel info and policies in a batch for
		// later validation.
		batch[dbChanInfo.channelID] = &migChanInfo{
			edge:    channel,
			policy1: policy1,
			policy2: policy2,
			dbInfo:  dbChanInfo,
		}

		if len(batch) >= int(cfg.QueryCfg.MaxBatchSize) {
			// Do batch validation.
			err := validateMigratedChannels(ctx, cfg, sqlDB, batch)
			if err != nil {
				return fmt.Errorf("could not validate "+
					"channel batch: %w", err)
			}

			batch = make(
				map[int64]*migChanInfo,
				cfg.QueryCfg.MaxBatchSize,
			)
		}

		s.Do(func() {
			elapsed := time.Since(t0).Seconds()
			ratePerSec := float64(chunk) / elapsed
			log.Debugf("Migrated %d channels (%.2f channels/sec)",
				channelCount, ratePerSec)

			t0 = time.Now()
			chunk = 0
		})

		return nil
	}, func() {
		channelCount = 0
		policyCount = 0
		chunk = 0
		skippedChanCount = 0
		skippedPolicyCount = 0
		t0 = time.Now()
		batch = make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)
	})
	if err != nil {
		return fmt.Errorf("could not migrate channels and policies: %w",
			err)
	}

	if len(batch) > 0 {
		// Do a final batch validation for any remaining channels.
		err := validateMigratedChannels(ctx, cfg, sqlDB, batch)
		if err != nil {
			return fmt.Errorf("could not validate final channel "+
				"batch: %w", err)
		}

		batch = make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)
	}

	log.Infof("Migrated %d channels and %d policies from KV to SQL in %s "+
		"(skipped %d channels and %d policies due to invalid TLV "+
		"streams)", channelCount, policyCount, time.Since(totalTime),
		skippedChanCount, skippedPolicyCount)

	return nil
}

// validateMigratedChannels validates the channels in the batch after they have
// been migrated to the SQL database. It batch fetches all channels by their IDs
// and compares the migrated channels and their policies with the original ones
// to ensure they match, using batch construction patterns.
func validateMigratedChannels(ctx context.Context, cfg *SQLStoreConfig,
	sqlDB SQLQueries, batch map[int64]*migChanInfo) error {

	// Convert batch keys (DB IDs) to an int slice for the batch query.
	dbChanIDs := make([]int64, 0, len(batch))
	for id := range batch {
		dbChanIDs = append(dbChanIDs, id)
	}

	// Batch fetch all channels with their policies.
	rows, err := sqlDB.GetChannelsByIDs(ctx, dbChanIDs)
	if err != nil {
		return fmt.Errorf("could not batch get channels by IDs: %w",
			err)
	}

	// Sanity check that the same number of channels were returned
	// as requested.
	if len(rows) != len(dbChanIDs) {
		return fmt.Errorf("expected to fetch %d channels, "+
			"but got %d", len(dbChanIDs), len(rows))
	}

	// Collect all policy IDs needed for batch data loading.
	dbPolicyIDs := make([]int64, 0, len(dbChanIDs)*2)

	for _, row := range rows {
		scid := byteOrder.Uint64(row.GraphChannel.Scid)

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return fmt.Errorf("could not extract channel policies"+
				" for SCID %d: %w", scid, err)
		}
		if dbPol1 != nil {
			dbPolicyIDs = append(dbPolicyIDs, dbPol1.ID)
		}
		if dbPol2 != nil {
			dbPolicyIDs = append(dbPolicyIDs, dbPol2.ID)
		}
	}

	// Batch load all channel and policy data (features, extras).
	batchData, err := batchLoadChannelData(
		ctx, cfg.QueryCfg, sqlDB, dbChanIDs, dbPolicyIDs,
	)
	if err != nil {
		return fmt.Errorf("could not batch load channel and policy "+
			"data: %w", err)
	}

	// Validate each channel in the batch using pre-loaded data.
	for _, row := range rows {
		kvdbChan, ok := batch[row.GraphChannel.ID]
		if !ok {
			return fmt.Errorf("channel with ID %d not found "+
				"in batch", row.GraphChannel.ID)
		}

		scid := byteOrder.Uint64(row.GraphChannel.Scid)

		err = validateMigratedChannelWithBatchData(
			cfg, scid, kvdbChan, row, batchData,
		)
		if err != nil {
			return fmt.Errorf("channel %d validation failed "+
				"after migration: %w", scid, err)
		}
	}

	return nil
}

// validateMigratedChannelWithBatchData validates a single migrated channel
// using pre-fetched batch data for optimal performance.
func validateMigratedChannelWithBatchData(cfg *SQLStoreConfig,
	scid uint64, info *migChanInfo, row sqlc.GetChannelsByIDsRow,
	batchData *batchChannelData) error {

	dbChanInfo := info.dbInfo
	channel := info.edge

	// Assert that the DB IDs for the channel and nodes are as expected
	// given the inserted channel info.
	err := sqldb.CompareRecords(
		dbChanInfo.channelID, row.GraphChannel.ID, "channel DB ID",
	)
	if err != nil {
		return err
	}
	err = sqldb.CompareRecords(
		dbChanInfo.node1ID, row.Node1ID, "node1 DB ID",
	)
	if err != nil {
		return err
	}
	err = sqldb.CompareRecords(
		dbChanInfo.node2ID, row.Node2ID, "node2 DB ID",
	)
	if err != nil {
		return err
	}

	// Build node vertices from the row data.
	node1, node2, err := buildNodeVertices(
		row.Node1PubKey, row.Node2PubKey,
	)
	if err != nil {
		return err
	}

	// Build channel info using batch data.
	migChan, err := buildEdgeInfoWithBatchData(
		cfg.ChainHash, row.GraphChannel, node1, node2, batchData,
	)
	if err != nil {
		return fmt.Errorf("could not build migrated channel info: %w",
			err)
	}

	// Extract channel policies from the row.
	dbPol1, dbPol2, err := extractChannelPolicies(row)
	if err != nil {
		return fmt.Errorf("could not extract channel policies: %w", err)
	}

	// Build channel policies using batch data.
	migPol1, migPol2, err := buildChanPoliciesWithBatchData(
		dbPol1, dbPol2, scid, node1, node2, batchData,
	)
	if err != nil {
		return fmt.Errorf("could not build migrated channel "+
			"policies: %w", err)
	}

	// Finally, compare the original channel info and
	// policies with the migrated ones to ensure they match.
	if len(channel.ExtraOpaqueData) == 0 {
		channel.ExtraOpaqueData = nil
	}
	if len(migChan.ExtraOpaqueData) == 0 {
		migChan.ExtraOpaqueData = nil
	}

	err = sqldb.CompareRecords(
		channel, migChan, fmt.Sprintf("channel %d", scid),
	)
	if err != nil {
		return err
	}

	checkPolicy := func(expPolicy,
		migPolicy *models.ChannelEdgePolicy) error {

		switch {
		// Both policies are nil, nothing to compare.
		case expPolicy == nil && migPolicy == nil:
			return nil

		// One of the policies is nil, but the other is not.
		case expPolicy == nil || migPolicy == nil:
			return fmt.Errorf("expected both policies to be "+
				"non-nil. Got expPolicy: %v, "+
				"migPolicy: %v", expPolicy, migPolicy)

		// Both policies are non-nil, we can compare them.
		default:
		}

		if len(expPolicy.ExtraOpaqueData) == 0 {
			expPolicy.ExtraOpaqueData = nil
		}
		if len(migPolicy.ExtraOpaqueData) == 0 {
			migPolicy.ExtraOpaqueData = nil
		}

		return sqldb.CompareRecords(
			*expPolicy, *migPolicy, "channel policy",
		)
	}

	err = checkPolicy(info.policy1, migPol1)
	if err != nil {
		return fmt.Errorf("policy1 mismatch for channel %d: %w", scid,
			err)
	}

	err = checkPolicy(info.policy2, migPol2)
	if err != nil {
		return fmt.Errorf("policy2 mismatch for channel %d: %w", scid,
			err)
	}

	return nil
}

// migratePruneLog migrates the prune log from the KV backend to the SQL
// database. It collects entries in batches, inserts them individually, and then
// validates them in batches using GetPruneEntriesForHeights for better
// performance.
func migratePruneLog(ctx context.Context, cfg *sqldb.QueryConfig,
	kvBackend kvdb.Backend, sqlDB SQLQueries) error {

	var (
		totalTime = time.Now()

		count          uint64
		pruneTipHeight uint32
		pruneTipHash   chainhash.Hash

		t0    = time.Now()
		chunk uint64
		s     = rate.Sometimes{
			Interval: 10 * time.Second,
		}
	)

	batch := make(map[uint32]chainhash.Hash, cfg.MaxBatchSize)

	// validateBatch validates a batch of prune entries using a batch query.
	validateBatch := func() error {
		if len(batch) == 0 {
			return nil
		}

		// Extract heights for the batch query.
		heights := make([]int64, 0, len(batch))
		for height := range batch {
			heights = append(heights, int64(height))
		}

		// Batch fetch all entries from the database.
		rows, err := sqlDB.GetPruneEntriesForHeights(ctx, heights)
		if err != nil {
			return fmt.Errorf("could not batch get prune "+
				"entries: %w", err)
		}

		if len(rows) != len(batch) {
			return fmt.Errorf("expected to fetch %d prune "+
				"entries, but got %d", len(batch),
				len(rows))
		}

		// Validate each entry in the batch.
		for _, row := range rows {
			kvdbHash, ok := batch[uint32(row.BlockHeight)]
			if !ok {
				return fmt.Errorf("prune entry for height %d "+
					"not found in batch", row.BlockHeight)
			}

			err := sqldb.CompareRecords(
				kvdbHash[:], row.BlockHash,
				fmt.Sprintf("prune log entry at height %d",
					row.BlockHeight),
			)
			if err != nil {
				return err
			}
		}

		// Reset the batch map for the next iteration.
		batch = make(map[uint32]chainhash.Hash, cfg.MaxBatchSize)

		return nil
	}

	// Iterate over each prune log entry in the KV store and migrate it to
	// the SQL database.
	err := forEachPruneLogEntry(
		kvBackend, func(height uint32, hash *chainhash.Hash) error {
			count++
			chunk++

			// Keep track of the prune tip height and hash.
			if height > pruneTipHeight {
				pruneTipHeight = height
				pruneTipHash = *hash
			}

			// Insert the entry (individual inserts for now).
			err := sqlDB.UpsertPruneLogEntry(
				ctx, sqlc.UpsertPruneLogEntryParams{
					BlockHeight: int64(height),
					BlockHash:   hash[:],
				},
			)
			if err != nil {
				return fmt.Errorf("unable to insert prune log "+
					"entry for height %d: %w", height, err)
			}

			// Add to validation batch.
			batch[height] = *hash

			// Validate batch when full.
			if len(batch) >= int(cfg.MaxBatchSize) {
				err := validateBatch()
				if err != nil {
					return fmt.Errorf("batch "+
						"validation failed: %w", err)
				}
			}

			s.Do(func() {
				elapsed := time.Since(t0).Seconds()
				ratePerSec := float64(chunk) / elapsed
				log.Debugf("Migrated %d prune log "+
					"entries (%.2f entries/sec)",
					count, ratePerSec)

				t0 = time.Now()
				chunk = 0
			})

			return nil
		},
		func() {
			count = 0
			chunk = 0
			t0 = time.Now()
			batch = make(
				map[uint32]chainhash.Hash, cfg.MaxBatchSize,
			)
		},
	)
	if err != nil {
		return fmt.Errorf("could not migrate prune log: %w", err)
	}

	// Validate any remaining entries in the batch.
	if len(batch) > 0 {
		err := validateBatch()
		if err != nil {
			return fmt.Errorf("final batch validation failed: %w",
				err)
		}
	}

	// Check that the prune tip is set correctly in the SQL
	// database.
	pruneTip, err := sqlDB.GetPruneTip(ctx)
	if errors.Is(err, sql.ErrNoRows) {
		// The ErrGraphNeverPruned error is expected if no prune log
		// entries were migrated from the kvdb store. Otherwise, it's
		// an unexpected error.
		if count == 0 {
			log.Infof("No prune log entries found in KV store " +
				"to migrate")
			return nil
		}
		// Fall-through to the next error check.
	}
	if err != nil {
		return fmt.Errorf("could not get prune tip: %w", err)
	}

	if pruneTip.BlockHeight != int64(pruneTipHeight) ||
		!bytes.Equal(pruneTip.BlockHash, pruneTipHash[:]) {

		return fmt.Errorf("prune tip mismatch after migration: "+
			"expected height %d, hash %s; got height %d, "+
			"hash %s", pruneTipHeight, pruneTipHash,
			pruneTip.BlockHeight,
			chainhash.Hash(pruneTip.BlockHash))
	}

	log.Infof("Migrated %d prune log entries from KV to SQL in %s. "+
		"The prune tip is: height %d, hash: %s", count,
		time.Since(totalTime), pruneTipHeight, pruneTipHash)

	return nil
}

// forEachPruneLogEntry iterates over each prune log entry in the KV
// backend and calls the provided callback function for each entry.
func forEachPruneLogEntry(db kvdb.Backend, cb func(height uint32,
	hash *chainhash.Hash) error, reset func()) error {

	return kvdb.View(db, func(tx kvdb.RTx) error {
		metaBucket := tx.ReadBucket(graphMetaBucket)
		if metaBucket == nil {
			return ErrGraphNotFound
		}

		pruneBucket := metaBucket.NestedReadBucket(pruneLogBucket)
		if pruneBucket == nil {
			// The graph has never been pruned and so, there are no
			// entries to iterate over.
			return nil
		}

		return pruneBucket.ForEach(func(k, v []byte) error {
			blockHeight := byteOrder.Uint32(k)
			var blockHash chainhash.Hash
			copy(blockHash[:], v)

			return cb(blockHeight, &blockHash)
		})
	}, reset)
}

// migrateClosedSCIDIndex migrates the closed SCID index from the KV backend to
// the SQL database. It collects SCIDs in batches, inserts them individually,
// and then validates them in batches using GetClosedChannelsSCIDs for better
// performance.
func migrateClosedSCIDIndex(ctx context.Context, cfg *sqldb.QueryConfig,
	kvBackend kvdb.Backend, sqlDB SQLQueries) error {

	var (
		totalTime = time.Now()

		count uint64

		t0    = time.Now()
		chunk uint64
		s     = rate.Sometimes{
			Interval: 10 * time.Second,
		}
	)

	batch := make([][]byte, 0, cfg.MaxBatchSize)

	// validateBatch validates a batch of closed SCIDs using a batch query.
	validateBatch := func() error {
		if len(batch) == 0 {
			return nil
		}

		// Batch fetch all closed SCIDs from the database.
		dbSCIDs, err := sqlDB.GetClosedChannelsSCIDs(ctx, batch)
		if err != nil {
			return fmt.Errorf("could not batch get closed "+
				"SCIDs: %w", err)
		}

		// Create set of SCIDs that exist in the database for quick
		// lookup.
		dbSCIDSet := make(map[string]struct{})
		for _, scid := range dbSCIDs {
			dbSCIDSet[string(scid)] = struct{}{}
		}

		// Validate each SCID in the batch.
		for _, expectedSCID := range batch {
			if _, found := dbSCIDSet[string(expectedSCID)]; !found {
				return fmt.Errorf("closed SCID %x not found "+
					"in database", expectedSCID)
			}
		}

		// Reset the batch for the next iteration.
		batch = make([][]byte, 0, cfg.MaxBatchSize)

		return nil
	}

	migrateSingleClosedSCID := func(scid lnwire.ShortChannelID) error {
		count++
		chunk++

		chanIDB := channelIDToBytes(scid.ToUint64())
		err := sqlDB.InsertClosedChannel(ctx, chanIDB)
		if err != nil {
			return fmt.Errorf("could not insert closed channel "+
				"with SCID %s: %w", scid, err)
		}

		// Add to validation batch.
		batch = append(batch, chanIDB)

		// Validate batch when full.
		if len(batch) >= int(cfg.MaxBatchSize) {
			err := validateBatch()
			if err != nil {
				return fmt.Errorf("batch validation failed: %w",
					err)
			}
		}

		s.Do(func() {
			elapsed := time.Since(t0).Seconds()
			ratePerSec := float64(chunk) / elapsed
			log.Debugf("Migrated %d closed scids "+
				"(%.2f entries/sec)", count, ratePerSec)

			t0 = time.Now()
			chunk = 0
		})

		return nil
	}

	err := forEachClosedSCID(
		kvBackend, migrateSingleClosedSCID, func() {
			count = 0
			chunk = 0
			t0 = time.Now()
			batch = make([][]byte, 0, cfg.MaxBatchSize)
		},
	)
	if err != nil {
		return fmt.Errorf("could not migrate closed SCID index: %w",
			err)
	}

	// Validate any remaining SCIDs in the batch.
	if len(batch) > 0 {
		err := validateBatch()
		if err != nil {
			return fmt.Errorf("final batch validation failed: %w",
				err)
		}
	}

	log.Infof("Migrated %d closed SCIDs from KV to SQL in %s", count,
		time.Since(totalTime))

	return nil
}

// migrateZombieIndex migrates the zombie index from the KV backend to the SQL
// database. It collects zombie channels in batches, inserts them individually,
// and validates them in batches.
//
// NOTE: before inserting an entry into the zombie index, the function checks
// if the channel is already marked as closed in the SQL store. If it is,
// the entry is skipped. This means that the resulting zombie index count in
// the SQL store may well be less than the count of zombie channels in the KV
// store.
func migrateZombieIndex(ctx context.Context, cfg *sqldb.QueryConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        var (
                totalTime = time.Now()

                count uint64

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )

        type zombieEntry struct {
                pub1 route.Vertex
                pub2 route.Vertex
        }

        batch := make(map[uint64]*zombieEntry, cfg.MaxBatchSize)

        // validateBatch validates a batch of zombie SCIDs using batch query.
        validateBatch := func() error {
                if len(batch) == 0 {
                        return nil
                }

                scids := make([][]byte, 0, len(batch))
                for scid := range batch {
                        scids = append(scids, channelIDToBytes(scid))
                }

                // Batch fetch all zombie channels from the database.
                rows, err := sqlDB.GetZombieChannelsSCIDs(
                        ctx, sqlc.GetZombieChannelsSCIDsParams{
                                Version: int16(ProtocolV1),
                                Scids:   scids,
                        },
                )
                if err != nil {
                        return fmt.Errorf("could not batch get zombie "+
                                "SCIDs: %w", err)
                }

                // Make sure that the number of rows returned matches
                // the number of SCIDs we requested.
                if len(rows) != len(scids) {
                        return fmt.Errorf("expected to fetch %d zombie "+
                                "SCIDs, but got %d", len(scids), len(rows))
                }

                // Validate each row is in the batch.
                for _, row := range rows {
                        scid := byteOrder.Uint64(row.Scid)

                        kvdbZombie, ok := batch[scid]
                        if !ok {
                                return fmt.Errorf("zombie SCID %x not found "+
                                        "in batch", scid)
                        }

                        err = sqldb.CompareRecords(
                                kvdbZombie.pub1[:], row.NodeKey1,
                                fmt.Sprintf("zombie pub key 1 (%s) for "+
                                        "channel %d", kvdbZombie.pub1, scid),
                        )
                        if err != nil {
                                return err
                        }

                        err = sqldb.CompareRecords(
                                kvdbZombie.pub2[:], row.NodeKey2,
                                fmt.Sprintf("zombie pub key 2 (%s) for "+
                                        "channel %d", kvdbZombie.pub2, scid),
                        )
                        if err != nil {
                                return err
                        }
                }

                // Reset the batch for the next iteration.
                batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize)

                return nil
        }

        err := forEachZombieEntry(kvBackend, func(chanID uint64, pubKey1,
                pubKey2 [33]byte) error {

                chanIDB := channelIDToBytes(chanID)

                // If it is in the closed SCID index, we don't need to
                // add it to the zombie index.
                //
                // NOTE: this means that the resulting zombie index count in
                // the SQL store may well be less than the count of zombie
                // channels in the KV store.
                isClosed, err := sqlDB.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("could not check closed "+
                                "channel: %w", err)
                }
                if isClosed {
                        return nil
                }

                count++
                chunk++

                err = sqlDB.UpsertZombieChannel(
                        ctx, sqlc.UpsertZombieChannelParams{
                                Version:  int16(ProtocolV1),
                                Scid:     chanIDB,
                                NodeKey1: pubKey1[:],
                                NodeKey2: pubKey2[:],
                        },
                )
                if err != nil {
                        return fmt.Errorf("could not upsert zombie "+
                                "channel %d: %w", chanID, err)
                }

                // Add to validation batch only after successful insertion.
                batch[chanID] = &zombieEntry{
                        pub1: pubKey1,
                        pub2: pubKey2,
                }

                // Validate batch when full.
                if len(batch) >= int(cfg.MaxBatchSize) {
                        err := validateBatch()
                        if err != nil {
                                return fmt.Errorf("batch validation failed: %w",
                                        err)
                        }
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d zombie index entries "+
                                "(%.2f entries/sec)", count, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }, func() {
                count = 0
                chunk = 0
                t0 = time.Now()
                batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize)
        })
        if err != nil {
                return fmt.Errorf("could not migrate zombie index: %w", err)
        }

        // Validate any remaining zombie SCIDs in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        log.Infof("Migrated %d zombie channels from KV to SQL in %s", count,
                time.Since(totalTime))

        return nil
}
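
// Both migration loops above throttle their progress output with
// rate.Sometimes so that large graphs do not flood the debug log. A minimal,
// self-contained sketch of the pattern (the interval and message are
// illustrative only):
//
//      s := rate.Sometimes{Interval: 10 * time.Second}
//      for i := 0; i < numEntries; i++ {
//              // ... migrate one entry ...
//              s.Do(func() {
//                      log.Debugf("migrated %d entries so far", i)
//              })
//      }
//
// rate.Sometimes.Do runs the callback at most once per Interval, which is why
// the chunk counter and t0 are reset inside the callback: together they give
// an approximate entries-per-second rate for the most recent window.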

// forEachZombieEntry iterates over each zombie channel entry in the
// KV backend and calls the provided callback function for each entry.
func forEachZombieEntry(db kvdb.Backend, cb func(chanID uint64, pubKey1,
        pubKey2 [33]byte) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                edges := tx.ReadBucket(edgeBucket)
                if edges == nil {
                        return ErrGraphNoEdgesFound
                }
                zombieIndex := edges.NestedReadBucket(zombieBucket)
                if zombieIndex == nil {
                        return nil
                }

                return zombieIndex.ForEach(func(k, v []byte) error {
                        var pubKey1, pubKey2 [33]byte
                        copy(pubKey1[:], v[:33])
                        copy(pubKey2[:], v[33:])

                        return cb(byteOrder.Uint64(k), pubKey1, pubKey2)
                })
        }, reset)
}
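
// Each value in the zombie bucket is expected to hold the two 33-byte
// compressed public keys of the channel's nodes back to back, which is why
// the callback above slices v[:33] and v[33:]. A minimal usage sketch (the
// callback body is illustrative only):
//
//      err := forEachZombieEntry(db, func(chanID uint64, pub1, pub2 [33]byte) error {
//              log.Debugf("zombie channel %d between %x and %x", chanID, pub1, pub2)
//              return nil
//      }, func() {})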

// forEachClosedSCID iterates over each closed SCID in the KV backend and calls
// the provided callback function for each SCID.
func forEachClosedSCID(db kvdb.Backend,
        cb func(lnwire.ShortChannelID) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                closedScids := tx.ReadBucket(closedScidBucket)
                if closedScids == nil {
                        return nil
                }

                return closedScids.ForEach(func(k, _ []byte) error {
                        return cb(lnwire.NewShortChanIDFromInt(
                                byteOrder.Uint64(k),
                        ))
                })
        }, reset)
}
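
// For the closed SCID bucket only the key carries information: it is read as
// an 8-byte SCID via byteOrder.Uint64 and the value is ignored. A minimal
// usage sketch:
//
//      err := forEachClosedSCID(db, func(scid lnwire.ShortChannelID) error {
//              log.Debugf("channel %s is known to be closed", scid)
//              return nil
//      }, func() {})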

// insertNodeSQLMig inserts the node record into the database during the graph
// SQL migration. No error is expected if the node already exists. Unlike the
// main upsertNode function, this function does not require that a new node
// update have a newer timestamp than the existing one. This is because we want
// the migration to be idempotent and don't want to error out if we re-insert
// the exact same node.
func insertNodeSQLMig(ctx context.Context, db SQLQueries,
        node *models.Node) (int64, error) {

        params := sqlc.InsertNodeMigParams{
                Version: int16(ProtocolV1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveNodeAnnouncement {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStrValid(EncodeHexColor(node.Color))
                params.Alias = sqldb.SQLStrValid(node.Alias)
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.InsertNodeMig(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveNodeAnnouncement {
                return nodeID, nil
        }

        // Insert the node's features.
        for feature := range node.Features.Features() {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: int32(feature),
                })
                if err != nil {
                        return 0, fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        // Update the node's addresses.
        newAddresses, err := collectAddressRecords(node.Addresses)
        if err != nil {
                return 0, err
        }

        // Any remaining entries in newAddresses are new addresses that need to
        // be added to the database for the first time.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.UpsertNodeAddress(
                                ctx, sqlc.UpsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return 0, fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Insert the node's extra signed fields.
        for tlvType, value := range extra {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return 0, fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nodeID, nil
}
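
// Because InsertNodeMig does not enforce a newer LastUpdate, re-running the
// migration and re-inserting the exact same node is expected to succeed rather
// than error, which keeps the migration idempotent. A minimal usage sketch:
//
//      nodeID, err := insertNodeSQLMig(ctx, db, node)
//      if err != nil {
//              return fmt.Errorf("node migration failed: %w", err)
//      }
//      // A second call with the same node is expected to return the same ID
//      // rather than an error.
//      sameID, err := insertNodeSQLMig(ctx, db, node)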

// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
// channel.
type dbChanInfo struct {
        channelID int64
        node1ID   int64
        node2ID   int64
}

// insertChannelMig inserts a new channel record into the database during the
// graph SQL migration.
func insertChannelMig(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        //
        // NOTE: we need this even during the SQL migration where nodes are
        // migrated first because there are cases where some nodes may have
        // been skipped due to invalid TLV data.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.InsertChannelMigParams{
                Version:     int16(ProtocolV1),
                Scid:        channelIDToBytes(edge.ChannelID),
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.InsertChannelMig(ctx, createParams)
        if err != nil {
                return nil, err
        }

        // Insert any channel features.
        for feature := range edge.Features.Features() {
                err = db.InsertChannelFeature(
                        ctx, sqlc.InsertChannelFeatureParams{
                                ChannelID:  dbChanID,
                                FeatureBit: int32(feature),
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
                                "feature(%v): %w", dbChanID, feature, err)
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return nil, fmt.Errorf("unable to marshal extra opaque "+
                        "data: %w", err)
        }

        for tlvType, value := range extra {
                err := db.UpsertChannelExtraType(
                        ctx, sqlc.UpsertChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to upsert "+
                                "channel(%d) extra signed field(%v): %w",
                                edge.ChannelID, tlvType, err)
                }
        }

        return &dbChanInfo{
                channelID: dbChanID,
                node1ID:   node1DBID,
                node2ID:   node2DBID,
        }, nil
}
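
// A minimal sketch of how the channel and policy helpers are expected to
// compose for a single migrated channel (the policy1/policy2 names are
// illustrative; nil policies are simply skipped):
//
//      chanInfo, err := insertChannelMig(ctx, db, edgeInfo)
//      if err != nil {
//              return fmt.Errorf("channel migration failed: %w", err)
//      }
//      for _, policy := range []*models.ChannelEdgePolicy{policy1, policy2} {
//              if policy == nil {
//                      continue
//              }
//              if err := insertChanEdgePolicyMig(ctx, db, chanInfo, policy); err != nil {
//                      return fmt.Errorf("policy migration failed: %w", err)
//              }
//      }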

// insertChanEdgePolicyMig inserts the channel policy info we have stored for
// a channel we already know of. This is used during the SQL migration
// process to insert channel policies.
func insertChanEdgePolicyMig(ctx context.Context, tx SQLQueries,
        dbChan *dbChanInfo, edge *models.ChannelEdgePolicy) error {

        // Figure out which node this edge is from.
        isNode1 := edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.node1ID
        if !isNode1 {
                nodeID = dbChan.node2ID
        }

        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.InsertEdgePolicyMig(ctx, sqlc.InsertEdgePolicyMigParams{
                Version:     int16(ProtocolV1),
                ChannelID:   dbChan.channelID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Insert all new extra signed fields for the channel policy.
        for tlvType, value := range extra {
                err = tx.UpsertChanPolicyExtraType(
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
                                ChannelPolicyID: id,
                                Type:            int64(tlvType),
                                Value:           value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert "+
                                "channel_policy(%d) extra signed field(%v): %w",
                                id, tlvType, err)
                }
        }

        return nil
}
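
// Both flag fields from the original channel_update drive how the policy row
// is shaped: the direction bit of ChannelFlags (lnwire.ChanUpdateDirection)
// selects node1ID versus node2ID, and the max-HTLC bit of MessageFlags decides
// whether MaxHtlcMsat is stored as a valid value or left NULL. A minimal
// sketch of the same decoding:
//
//      fromNode1 := edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
//      maxHtlc := sql.NullInt64{
//              Valid: edge.MessageFlags.HasMaxHtlc(),
//              Int64: int64(edge.MaxHTLC),
//      }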