lightningnetwork / lnd / 19316446305

13 Nov 2025 12:34AM UTC coverage: 65.219% (+8.3%) from 56.89%
Build 19316446305: push via github (web-flow)
Merge pull request #10343 from lightningnetwork/0-21-0-staging
Merge branch `0-21-staging`

361 of 5339 new or added lines in 47 files covered (6.76%).
34 existing lines in 8 files now uncovered.
137571 of 210938 relevant lines covered (65.22%).
20832.75 hits per line.

Source File
/graph/db/migration1/sql_migration.go (0.0% covered)

package migration1

import (
        "bytes"
        "cmp"
        "context"
        "database/sql"
        "errors"
        "fmt"
        "image/color"
        "net"
        "slices"
        "time"

        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/lightningnetwork/lnd/graph/db/migration1/models"
        "github.com/lightningnetwork/lnd/graph/db/migration1/sqlc"
        "github.com/lightningnetwork/lnd/kvdb"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "golang.org/x/time/rate"
)

// MigrateGraphToSQL migrates the graph store from a KV backend to a SQL
// backend.
//
// NOTE: this is currently not called from any code path. It is called via tests
// only for now and will be called from the main lnd binary once the
// migration is fully implemented and tested.
func MigrateGraphToSQL(ctx context.Context, cfg *SQLStoreConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        log.Infof("Starting migration of the graph store from KV to SQL")
        t0 := time.Now()

        // Check if there is a graph to migrate.
        graphExists, err := checkGraphExists(kvBackend)
        if err != nil {
                return fmt.Errorf("failed to check graph existence: %w", err)
        }
        if !graphExists {
                log.Infof("No graph found in KV store, skipping the migration")
                return nil
        }

        // 1) Migrate all the nodes.
        err = migrateNodes(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate nodes: %w", err)
        }

        // 2) Migrate the source node.
        if err := migrateSourceNode(ctx, kvBackend, sqlDB); err != nil {
                return fmt.Errorf("could not migrate source node: %w", err)
        }

        // 3) Migrate all the channels and channel policies.
        err = migrateChannelsAndPolicies(ctx, cfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate channels and policies: %w",
                        err)
        }

        // 4) Migrate the prune log.
        err = migratePruneLog(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate prune log: %w", err)
        }

        // 5) Migrate the closed SCID index.
        err = migrateClosedSCIDIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate closed SCID index: %w",
                        err)
        }

        // 6) Migrate the zombie index.
        err = migrateZombieIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate zombie index: %w", err)
        }

        log.Infof("Finished migration of the graph store from KV to SQL in %v",
                time.Since(t0))

        return nil
}
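
// As noted above, MigrateGraphToSQL is currently only exercised from tests. A
// minimal sketch of such a call site is shown below; newTestKVBackend and
// newTestSQLQueries are hypothetical test fixtures, and only the
// MigrateGraphToSQL signature itself comes from this file:
//
//      func runMigrationSketch(ctx context.Context, cfg *SQLStoreConfig) error {
//              kvBackend := newTestKVBackend()  // hypothetical KV store fixture
//              sqlDB := newTestSQLQueries()     // hypothetical SQLQueries fixture
//
//              return MigrateGraphToSQL(ctx, cfg, kvBackend, sqlDB)
//      }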

// checkGraphExists checks if the graph exists in the KV backend.
func checkGraphExists(db kvdb.Backend) (bool, error) {
        // Check if there is even a graph to migrate.
        err := db.View(func(tx kvdb.RTx) error {
                // Check for the existence of the node bucket which is a top
                // level bucket that would have been created on the initial
                // creation of the graph store.
                nodes := tx.ReadBucket(nodeBucket)
                if nodes == nil {
                        return ErrGraphNotFound
                }

                return nil
        }, func() {})
        if errors.Is(err, ErrGraphNotFound) {
                return false, nil
        } else if err != nil {
                return false, err
        }

        return true, nil
}

// migrateNodes migrates all nodes from the KV backend to the SQL database.
// It collects nodes in batches, inserts them individually, and then validates
// them in batches.
func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        // Keep track of the number of nodes migrated and the number of
        // nodes skipped due to errors.
        var (
                totalTime = time.Now()

                count   uint64
                skipped uint64

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )

        // batch is a map that holds node objects that have been migrated to
        // the native SQL store but have yet to be validated. The objects held
        // by this map were derived from the KVDB store and so when they are
        // validated, the map index (the SQL store node ID) will be used to
        // fetch the corresponding node object in the SQL store, and it will
        // then be compared against the original KVDB node object.
        batch := make(
                map[int64]*models.Node, cfg.MaxBatchSize,
        )

        // validateBatch validates that the batch of nodes in the 'batch' map
        // have been migrated successfully.
        validateBatch := func() error {
                if len(batch) == 0 {
                        return nil
                }

                // Extract DB node IDs.
                dbIDs := make([]int64, 0, len(batch))
                for dbID := range batch {
                        dbIDs = append(dbIDs, dbID)
                }

                // Batch fetch all nodes from the database.
                dbNodes, err := sqlDB.GetNodesByIDs(ctx, dbIDs)
                if err != nil {
                        return fmt.Errorf("could not batch fetch nodes: %w",
                                err)
                }

                // Make sure that the number of nodes fetched matches the number
                // of nodes in the batch.
                if len(dbNodes) != len(batch) {
                        return fmt.Errorf("expected to fetch %d nodes, "+
                                "but got %d", len(batch), len(dbNodes))
                }

                // Now, batch fetch the normalised data for all the nodes in
                // the batch.
                batchData, err := batchLoadNodeData(ctx, cfg, sqlDB, dbIDs)
                if err != nil {
                        return fmt.Errorf("unable to batch load node data: %w",
                                err)
                }

                for _, dbNode := range dbNodes {
                        // Get the KVDB node info from the batch map.
                        node, ok := batch[dbNode.ID]
                        if !ok {
                                return fmt.Errorf("node with ID %d not found "+
                                        "in batch", dbNode.ID)
                        }

                        // Build the migrated node from the DB node and the
                        // batch node data.
                        migNode, err := buildNodeWithBatchData(
                                dbNode, batchData,
                        )
                        if err != nil {
                                return fmt.Errorf("could not build migrated "+
                                        "node from dbNode(db id: %d, node "+
                                        "pub: %x): %w", dbNode.ID,
                                        node.PubKeyBytes, err)
                        }

                        // Make sure that the node addresses are sorted before
                        // comparing them to ensure that the order of addresses
                        // does not affect the comparison.
                        slices.SortFunc(
                                node.Addresses, func(i, j net.Addr) int {
                                        return cmp.Compare(
                                                i.String(), j.String(),
                                        )
                                },
                        )
                        slices.SortFunc(
                                migNode.Addresses, func(i, j net.Addr) int {
                                        return cmp.Compare(
                                                i.String(), j.String(),
                                        )
                                },
                        )

                        err = sqldb.CompareRecords(
                                node, migNode,
                                fmt.Sprintf("node %x", node.PubKeyBytes),
                        )
                        if err != nil {
                                return fmt.Errorf("node mismatch after "+
                                        "migration for node %x: %w",
                                        node.PubKeyBytes, err)
                        }
                }

                // Clear the batch map for the next iteration.
                batch = make(
                        map[int64]*models.Node, cfg.MaxBatchSize,
                )

                return nil
        }

        // Loop through each node in the KV store and insert it into the SQL
        // database.
        err := forEachNode(kvBackend, func(_ kvdb.RTx,
                node *models.Node) error {

                pub := node.PubKeyBytes

                // Sanity check to ensure that the node has valid extra opaque
                // data. If it does not, we'll skip it. We need to do this
                // because previously we would just persist any TLV bytes that
                // we received without validating them. Now, however, we
                // normalise the storage of extra opaque data, so we need to
                // ensure that the data is valid. We don't want to abort the
                // migration if we encounter a node with invalid extra opaque
                // data, so we'll just skip it and log a warning.
                _, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
                if errors.Is(err, ErrParsingExtraTLVBytes) {
                        skipped++
                        log.Warnf("Skipping migration of node %x with invalid "+
                                "extra opaque data: %v", pub,
                                node.ExtraOpaqueData)

                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to marshal extra "+
                                "opaque data for node %x: %w", pub, err)
                }

                if err = maybeOverrideNodeAddresses(node); err != nil {
                        skipped++
                        log.Warnf("Skipping migration of node %x with invalid "+
                                "address (%v): %v", pub, node.Addresses, err)

                        return nil
                }

                count++
                chunk++

                // Write the node to the SQL database.
                id, err := insertNodeSQLMig(ctx, sqlDB, node)
                if err != nil {
                        return fmt.Errorf("could not persist node(%x): %w", pub,
                                err)
                }

                // Add to validation batch.
                batch[id] = node

                // Validate batch when full.
                if len(batch) >= int(cfg.MaxBatchSize) {
                        err := validateBatch()
                        if err != nil {
                                return fmt.Errorf("batch validation failed: %w",
                                        err)
                        }
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d nodes (%.2f nodes/sec)",
                                count, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }, func() {
                count = 0
                chunk = 0
                skipped = 0
                t0 = time.Now()
                batch = make(map[int64]*models.Node, cfg.MaxBatchSize)
        })
        if err != nil {
                return fmt.Errorf("could not migrate nodes: %w", err)
        }

        // Validate any remaining nodes in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        log.Infof("Migrated %d nodes from KV to SQL in %v (skipped %d nodes "+
                "due to invalid TLV streams or invalid addresses)", count,
                time.Since(totalTime), skipped)

        return nil
}

// maybeOverrideNodeAddresses checks if the node has any opaque addresses that
// can be parsed. If so, it replaces the node's addresses with the parsed
// addresses. If the address is unparseable, it returns an error.
func maybeOverrideNodeAddresses(node *models.Node) error {
        // In the majority of cases, the number of node addresses will remain
        // unchanged, so we pre-allocate a slice of the same length.
        addrs := make([]net.Addr, 0, len(node.Addresses))

        // Iterate over each address in search of any opaque addresses that we
        // can inspect.
        for _, addr := range node.Addresses {
                opaque, ok := addr.(*lnwire.OpaqueAddrs)
                if !ok {
                        // Any non-opaque address is left unchanged.
                        addrs = append(addrs, addr)
                        continue
                }

                // For each opaque address, we'll now attempt to parse out any
                // known addresses. We'll do this in a loop, as it's possible
                // that there are several addresses encoded in a single opaque
                // address.
                payload := opaque.Payload
                for len(payload) > 0 {
                        var (
                                r            = bytes.NewReader(payload)
                                numAddrBytes = uint16(len(payload))
                        )
                        byteRead, readAddr, err := lnwire.ReadAddress(
                                r, numAddrBytes,
                        )
                        if err != nil {
                                return err
                        }

                        // If we were able to read an address, we'll add it to
                        // our list of addresses.
                        if readAddr != nil {
                                addrs = append(addrs, readAddr)
                        }

                        // If the address we read was an opaque address, it
                        // means we've hit an unknown address type, and it has
                        // consumed the rest of the payload. We can break out
                        // of the loop.
                        if _, ok := readAddr.(*lnwire.OpaqueAddrs); ok {
                                break
                        }

                        // If we've read all the bytes, we can also break.
                        if byteRead >= numAddrBytes {
                                break
                        }

                        // Otherwise, we'll advance our payload slice and
                        // continue.
                        payload = payload[byteRead:]
                }
        }

        // Override the node addresses if we have any.
        if len(addrs) != 0 {
                node.Addresses = addrs
        }

        return nil
}
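
// For illustration: assuming the standard BOLT 7 address encoding that
// lnwire.ReadAddress understands, an opaque payload of
// 0x01 0x7f 0x00 0x00 0x01 0x26 0x07 (descriptor type 1, i.e. IPv4, followed
// by four address bytes and a two-byte port) would be re-parsed by the loop in
// maybeOverrideNodeAddresses above into the single address 127.0.0.1:9735.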

// migrateSourceNode migrates the source node from the KV backend to the
// SQL database.
func migrateSourceNode(ctx context.Context, kvdb kvdb.Backend,
        sqlDB SQLQueries) error {

        log.Debugf("Migrating source node from KV to SQL")

        sourceNode, err := sourceNode(kvdb)
        if errors.Is(err, ErrSourceNodeNotSet) {
                // If the source node has not been set yet, we can skip this
                // migration step.
                return nil
        } else if err != nil {
                return fmt.Errorf("could not get source node from kv "+
                        "store: %w", err)
        }

        pub := sourceNode.PubKeyBytes

        // Get the DB ID of the source node by its public key. This node must
        // already exist in the SQL database, as it should have been migrated
        // in the previous node-migration step.
        id, err := sqlDB.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        PubKey:  pub[:],
                        Version: int16(lnwire.GossipVersion1),
                },
        )
        if err != nil {
                return fmt.Errorf("could not get source node ID: %w", err)
        }

        // Now we can add the source node to the SQL database.
        err = sqlDB.AddSourceNode(ctx, id)
        if err != nil {
                return fmt.Errorf("could not add source node to SQL store: %w",
                        err)
        }

        // Verify that the source node was added correctly by fetching it back
        // from the SQL database and checking that the expected DB ID and
        // pub key are returned. We don't need to do a whole node comparison
        // here, as this was already done in the previous migration step.
        srcNodes, err := sqlDB.GetSourceNodesByVersion(
                ctx, int16(lnwire.GossipVersion1),
        )
        if err != nil {
                return fmt.Errorf("could not get source nodes from SQL "+
                        "store: %w", err)
        }

        // The SQL store has support for multiple source nodes (for future
        // protocol versions) but this migration is purely aimed at the V1
        // store, and so we expect exactly one source node to be present.
        if len(srcNodes) != 1 {
                return fmt.Errorf("expected exactly one source node, "+
                        "got %d", len(srcNodes))
        }

        // Check that the source node ID and pub key match the original
        // source node.
        if srcNodes[0].NodeID != id {
                return fmt.Errorf("source node ID mismatch after migration: "+
                        "expected %d, got %d", id, srcNodes[0].NodeID)
        }
        err = sqldb.CompareRecords(pub[:], srcNodes[0].PubKey, "source node")
        if err != nil {
                return fmt.Errorf("source node pubkey mismatch after "+
                        "migration: %w", err)
        }

        log.Infof("Migrated source node with pubkey %x to SQL", pub[:])

        return nil
}

// migChanInfo holds the information about a channel and its policies.
type migChanInfo struct {
        // edge is the channel object as read from the KVDB source.
        edge *models.ChannelEdgeInfo

        // policy1 is the first channel policy for the channel as read from
        // the KVDB source.
        policy1 *models.ChannelEdgePolicy

        // policy2 is the second channel policy for the channel as read
        // from the KVDB source.
        policy2 *models.ChannelEdgePolicy

        // dbInfo holds location info (in the form of DB IDs) of the channel
        // and its policies in the native-SQL destination.
        dbInfo *dbChanInfo
}

// migrateChannelsAndPolicies migrates all channels and their policies
// from the KV backend to the SQL database.
func migrateChannelsAndPolicies(ctx context.Context, cfg *SQLStoreConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        var (
                totalTime = time.Now()

                channelCount       uint64
                skippedChanCount   uint64
                policyCount        uint64
                skippedPolicyCount uint64

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )
        migChanPolicy := func(dbChanInfo *dbChanInfo,
                policy *models.ChannelEdgePolicy) error {

                // If the policy is nil, we can skip it.
                if policy == nil {
                        return nil
                }

                // Unlike the special case of invalid TLV bytes for node and
                // channel announcements, we don't need to handle the case for
                // channel policies here because it is already handled in the
                // `forEachChannel` function. If the policy has invalid TLV
                // bytes, then `nil` will be passed to this function.

                policyCount++

                err := insertChanEdgePolicyMig(ctx, sqlDB, dbChanInfo, policy)
                if err != nil {
                        return fmt.Errorf("could not migrate channel "+
                                "policy %d: %w", policy.ChannelID, err)
                }

                return nil
        }

        // batch is used to collect migrated channel info that we will
        // batch-validate. Each entry is indexed by the DB ID of the channel
        // in the SQL database.
        batch := make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)

        // Iterate over each channel in the KV store and migrate it and its
        // policies to the SQL database.
        err := forEachChannel(kvBackend, func(channel *models.ChannelEdgeInfo,
                policy1 *models.ChannelEdgePolicy,
                policy2 *models.ChannelEdgePolicy) error {

                scid := channel.ChannelID

                // Here, we do a sanity check to ensure that the chain hash of
                // the channel returned by the KV store matches the expected
                // chain hash. This is important since in the SQL store, we will
                // no longer explicitly store the chain hash in the channel
                // info, but rather rely on the chain hash LND is running with.
                // So this is our way of ensuring that LND is running on the
                // correct network at migration time.
                if channel.ChainHash != cfg.ChainHash {
                        return fmt.Errorf("channel %d has chain hash %s, "+
                                "expected %s", scid, channel.ChainHash,
                                cfg.ChainHash)
                }

                // Sanity check to ensure that the channel has valid extra
                // opaque data. If it does not, we'll skip it. We need to do
                // this because previously we would just persist any TLV bytes
                // that we received without validating them. Now, however, we
                // normalise the storage of extra opaque data, so we need to
                // ensure that the data is valid. We don't want to abort the
                // migration if we encounter a channel with invalid extra opaque
                // data, so we'll just skip it and log a warning.
                _, err := marshalExtraOpaqueData(channel.ExtraOpaqueData)
                if errors.Is(err, ErrParsingExtraTLVBytes) {
                        log.Warnf("Skipping channel %d with invalid "+
                                "extra opaque data: %v", scid,
                                channel.ExtraOpaqueData)

                        skippedChanCount++

                        // If we skip a channel, we also skip its policies.
                        if policy1 != nil {
                                skippedPolicyCount++
                        }
                        if policy2 != nil {
                                skippedPolicyCount++
                        }

                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to marshal extra opaque "+
                                "data for channel %d (%v): %w", scid,
                                channel.ExtraOpaqueData, err)
                }

                channelCount++
                chunk++

                // Migrate the channel info along with its policies.
                dbChanInfo, err := insertChannelMig(ctx, sqlDB, channel)
                if err != nil {
                        return fmt.Errorf("could not insert record for "+
                                "channel %d in SQL store: %w", scid, err)
                }

                // Now, migrate the two channel policies for the channel.
                err = migChanPolicy(dbChanInfo, policy1)
                if err != nil {
                        return fmt.Errorf("could not migrate policy1(%d): %w",
                                scid, err)
                }
                err = migChanPolicy(dbChanInfo, policy2)
                if err != nil {
                        return fmt.Errorf("could not migrate policy2(%d): %w",
                                scid, err)
                }

                // Collect the migrated channel info and policies in a batch for
                // later validation.
                batch[dbChanInfo.channelID] = &migChanInfo{
                        edge:    channel,
                        policy1: policy1,
                        policy2: policy2,
                        dbInfo:  dbChanInfo,
                }

                if len(batch) >= int(cfg.QueryCfg.MaxBatchSize) {
                        // Do batch validation.
                        err := validateMigratedChannels(ctx, cfg, sqlDB, batch)
                        if err != nil {
                                return fmt.Errorf("could not validate "+
                                        "channel batch: %w", err)
                        }

                        batch = make(
                                map[int64]*migChanInfo,
                                cfg.QueryCfg.MaxBatchSize,
                        )
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d channels (%.2f channels/sec)",
                                channelCount, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }, func() {
                channelCount = 0
                policyCount = 0
                chunk = 0
                skippedChanCount = 0
                skippedPolicyCount = 0
                t0 = time.Now()
                batch = make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)
        })
        if err != nil {
                return fmt.Errorf("could not migrate channels and policies: %w",
                        err)
        }

        if len(batch) > 0 {
                // Do a final batch validation for any remaining channels.
                err := validateMigratedChannels(ctx, cfg, sqlDB, batch)
                if err != nil {
                        return fmt.Errorf("could not validate final channel "+
                                "batch: %w", err)
                }

                batch = make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)
        }

        log.Infof("Migrated %d channels and %d policies from KV to SQL in %s "+
                "(skipped %d channels and %d policies due to invalid TLV "+
                "streams)", channelCount, policyCount, time.Since(totalTime),
                skippedChanCount, skippedPolicyCount)

        return nil
}

// validateMigratedChannels validates the channels in the batch after they have
// been migrated to the SQL database. It batch fetches all channels by their IDs
// and compares the migrated channels and their policies with the original ones
// to ensure they match, using batch construction patterns.
func validateMigratedChannels(ctx context.Context, cfg *SQLStoreConfig,
        sqlDB SQLQueries, batch map[int64]*migChanInfo) error {

        // Convert batch keys (DB IDs) to an int slice for the batch query.
        dbChanIDs := make([]int64, 0, len(batch))
        for id := range batch {
                dbChanIDs = append(dbChanIDs, id)
        }

        // Batch fetch all channels with their policies.
        rows, err := sqlDB.GetChannelsByIDs(ctx, dbChanIDs)
        if err != nil {
                return fmt.Errorf("could not batch get channels by IDs: %w",
                        err)
        }

        // Sanity check that the same number of channels were returned
        // as requested.
        if len(rows) != len(dbChanIDs) {
                return fmt.Errorf("expected to fetch %d channels, "+
                        "but got %d", len(dbChanIDs), len(rows))
        }

        // Collect all policy IDs needed for batch data loading.
        dbPolicyIDs := make([]int64, 0, len(dbChanIDs)*2)

        for _, row := range rows {
                scid := byteOrder.Uint64(row.GraphChannel.Scid)

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("could not extract channel policies"+
                                " for SCID %d: %w", scid, err)
                }
                if dbPol1 != nil {
                        dbPolicyIDs = append(dbPolicyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        dbPolicyIDs = append(dbPolicyIDs, dbPol2.ID)
                }
        }

        // Batch load all channel and policy data (features, extras).
        batchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, sqlDB, dbChanIDs, dbPolicyIDs,
        )
        if err != nil {
                return fmt.Errorf("could not batch load channel and policy "+
                        "data: %w", err)
        }

        // Validate each channel in the batch using pre-loaded data.
        for _, row := range rows {
                kvdbChan, ok := batch[row.GraphChannel.ID]
                if !ok {
                        return fmt.Errorf("channel with ID %d not found "+
                                "in batch", row.GraphChannel.ID)
                }

                scid := byteOrder.Uint64(row.GraphChannel.Scid)

                err = validateMigratedChannelWithBatchData(
                        cfg, scid, kvdbChan, row, batchData,
                )
                if err != nil {
                        return fmt.Errorf("channel %d validation failed "+
                                "after migration: %w", scid, err)
                }
        }

        return nil
}

// validateMigratedChannelWithBatchData validates a single migrated channel
// using pre-fetched batch data for optimal performance.
func validateMigratedChannelWithBatchData(cfg *SQLStoreConfig,
        scid uint64, info *migChanInfo, row sqlc.GetChannelsByIDsRow,
        batchData *batchChannelData) error {

        dbChanInfo := info.dbInfo
        channel := info.edge

        // Assert that the DB IDs for the channel and nodes are as expected
        // given the inserted channel info.
        err := sqldb.CompareRecords(
                dbChanInfo.channelID, row.GraphChannel.ID, "channel DB ID",
        )
        if err != nil {
                return err
        }
        err = sqldb.CompareRecords(
                dbChanInfo.node1ID, row.Node1ID, "node1 DB ID",
        )
        if err != nil {
                return err
        }
        err = sqldb.CompareRecords(
                dbChanInfo.node2ID, row.Node2ID, "node2 DB ID",
        )
        if err != nil {
                return err
        }

        // Build node vertices from the row data.
        node1, node2, err := buildNodeVertices(
                row.Node1PubKey, row.Node2PubKey,
        )
        if err != nil {
                return err
        }

        // Build channel info using batch data.
        migChan, err := buildEdgeInfoWithBatchData(
                cfg.ChainHash, row.GraphChannel, node1, node2, batchData,
        )
        if err != nil {
                return fmt.Errorf("could not build migrated channel info: %w",
                        err)
        }

        // Extract channel policies from the row.
        dbPol1, dbPol2, err := extractChannelPolicies(row)
        if err != nil {
                return fmt.Errorf("could not extract channel policies: %w", err)
        }

        // Build channel policies using batch data.
        migPol1, migPol2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, scid, node1, node2, batchData,
        )
        if err != nil {
                return fmt.Errorf("could not build migrated channel "+
                        "policies: %w", err)
        }

        // Finally, compare the original channel info and
        // policies with the migrated ones to ensure they match.
        if len(channel.ExtraOpaqueData) == 0 {
                channel.ExtraOpaqueData = nil
        }
        if len(migChan.ExtraOpaqueData) == 0 {
                migChan.ExtraOpaqueData = nil
        }

        err = sqldb.CompareRecords(
                channel, migChan, fmt.Sprintf("channel %d", scid),
        )
        if err != nil {
                return err
        }

        checkPolicy := func(expPolicy,
                migPolicy *models.ChannelEdgePolicy) error {

                switch {
                // Both policies are nil, nothing to compare.
                case expPolicy == nil && migPolicy == nil:
                        return nil

                // One of the policies is nil, but the other is not.
                case expPolicy == nil || migPolicy == nil:
                        return fmt.Errorf("expected both policies to be "+
                                "non-nil. Got expPolicy: %v, "+
                                "migPolicy: %v", expPolicy, migPolicy)

                // Both policies are non-nil, we can compare them.
                default:
                }

                if len(expPolicy.ExtraOpaqueData) == 0 {
                        expPolicy.ExtraOpaqueData = nil
                }
                if len(migPolicy.ExtraOpaqueData) == 0 {
                        migPolicy.ExtraOpaqueData = nil
                }

                return sqldb.CompareRecords(
                        *expPolicy, *migPolicy, "channel policy",
                )
        }

        err = checkPolicy(info.policy1, migPol1)
        if err != nil {
                return fmt.Errorf("policy1 mismatch for channel %d: %w", scid,
                        err)
        }

        err = checkPolicy(info.policy2, migPol2)
        if err != nil {
                return fmt.Errorf("policy2 mismatch for channel %d: %w", scid,
                        err)
        }

        return nil
}

// migratePruneLog migrates the prune log from the KV backend to the SQL
// database. It collects entries in batches, inserts them individually, and
// then validates them in batches using GetPruneEntriesForHeights for better
// performance.
func migratePruneLog(ctx context.Context, cfg *sqldb.QueryConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        var (
                totalTime = time.Now()

                count          uint64
                pruneTipHeight uint32
                pruneTipHash   chainhash.Hash

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )

        batch := make(map[uint32]chainhash.Hash, cfg.MaxBatchSize)

        // validateBatch validates a batch of prune entries using batch query.
        validateBatch := func() error {
                if len(batch) == 0 {
                        return nil
                }

                // Extract heights for the batch query.
                heights := make([]int64, 0, len(batch))
                for height := range batch {
                        heights = append(heights, int64(height))
                }

                // Batch fetch all entries from the database.
                rows, err := sqlDB.GetPruneEntriesForHeights(ctx, heights)
                if err != nil {
                        return fmt.Errorf("could not batch get prune "+
                                "entries: %w", err)
                }

                if len(rows) != len(batch) {
                        return fmt.Errorf("expected to fetch %d prune "+
                                "entries, but got %d", len(batch),
                                len(rows))
                }

                // Validate each entry in the batch.
                for _, row := range rows {
                        kvdbHash, ok := batch[uint32(row.BlockHeight)]
                        if !ok {
                                return fmt.Errorf("prune entry for height %d "+
                                        "not found in batch", row.BlockHeight)
                        }

                        err := sqldb.CompareRecords(
                                kvdbHash[:], row.BlockHash,
                                fmt.Sprintf("prune log entry at height %d",
                                        row.BlockHeight),
                        )
                        if err != nil {
                                return err
                        }
                }

                // Reset the batch map for the next iteration.
                batch = make(map[uint32]chainhash.Hash, cfg.MaxBatchSize)

                return nil
        }

        // Iterate over each prune log entry in the KV store and migrate it to
        // the SQL database.
        err := forEachPruneLogEntry(
                kvBackend, func(height uint32, hash *chainhash.Hash) error {
                        count++
                        chunk++

                        // Keep track of the prune tip height and hash.
                        if height > pruneTipHeight {
                                pruneTipHeight = height
                                pruneTipHash = *hash
                        }

                        // Insert the entry (individual inserts for now).
                        err := sqlDB.UpsertPruneLogEntry(
                                ctx, sqlc.UpsertPruneLogEntryParams{
                                        BlockHeight: int64(height),
                                        BlockHash:   hash[:],
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert prune log "+
                                        "entry for height %d: %w", height, err)
                        }

                        // Add to validation batch.
                        batch[height] = *hash

                        // Validate batch when full.
                        if len(batch) >= int(cfg.MaxBatchSize) {
                                err := validateBatch()
                                if err != nil {
                                        return fmt.Errorf("batch "+
                                                "validation failed: %w", err)
                                }
                        }

                        s.Do(func() {
                                elapsed := time.Since(t0).Seconds()
                                ratePerSec := float64(chunk) / elapsed
                                log.Debugf("Migrated %d prune log "+
                                        "entries (%.2f entries/sec)",
                                        count, ratePerSec)

                                t0 = time.Now()
                                chunk = 0
                        })

                        return nil
                },
                func() {
                        count = 0
                        chunk = 0
                        t0 = time.Now()
                        batch = make(
                                map[uint32]chainhash.Hash, cfg.MaxBatchSize,
                        )
                },
        )
        if err != nil {
                return fmt.Errorf("could not migrate prune log: %w", err)
        }

        // Validate any remaining entries in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        // Check that the prune tip is set correctly in the SQL
        // database.
        pruneTip, err := sqlDB.GetPruneTip(ctx)
        if errors.Is(err, sql.ErrNoRows) {
                // The ErrGraphNeverPruned error is expected if no prune log
                // entries were migrated from the kvdb store. Otherwise, it's
                // an unexpected error.
                if count == 0 {
                        log.Infof("No prune log entries found in KV store " +
                                "to migrate")
                        return nil
                }
                // Fall-through to the next error check.
        }
        if err != nil {
                return fmt.Errorf("could not get prune tip: %w", err)
        }

        if pruneTip.BlockHeight != int64(pruneTipHeight) ||
                !bytes.Equal(pruneTip.BlockHash, pruneTipHash[:]) {

                return fmt.Errorf("prune tip mismatch after migration: "+
                        "expected height %d, hash %s; got height %d, "+
                        "hash %s", pruneTipHeight, pruneTipHash,
                        pruneTip.BlockHeight,
                        chainhash.Hash(pruneTip.BlockHash))
        }

        log.Infof("Migrated %d prune log entries from KV to SQL in %s. "+
                "The prune tip is: height %d, hash: %s", count,
                time.Since(totalTime), pruneTipHeight, pruneTipHash)

        return nil
}

// forEachPruneLogEntry iterates over each prune log entry in the KV
// backend and calls the provided callback function for each entry.
func forEachPruneLogEntry(db kvdb.Backend, cb func(height uint32,
        hash *chainhash.Hash) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                metaBucket := tx.ReadBucket(graphMetaBucket)
                if metaBucket == nil {
                        return ErrGraphNotFound
                }

                pruneBucket := metaBucket.NestedReadBucket(pruneLogBucket)
                if pruneBucket == nil {
                        // The graph has never been pruned and so, there are no
                        // entries to iterate over.
                        return nil
                }

                return pruneBucket.ForEach(func(k, v []byte) error {
                        blockHeight := byteOrder.Uint32(k)
                        var blockHash chainhash.Hash
                        copy(blockHash[:], v)

                        return cb(blockHeight, &blockHash)
                })
        }, reset)
}

1093
// migrateClosedSCIDIndex migrates the closed SCID index from the KV backend to
1094
// the SQL database. It collects SCIDs in batches, inserts them individually,
1095
// and then validates them in batches using GetClosedChannelsSCIDs for better
1096
// performance.
1097
func migrateClosedSCIDIndex(ctx context.Context, cfg *sqldb.QueryConfig,
1098
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {
×
1099

×
1100
        var (
×
1101
                totalTime = time.Now()
×
1102

×
1103
                count uint64
×
1104

×
1105
                t0    = time.Now()
×
1106
                chunk uint64
×
1107
                s     = rate.Sometimes{
×
1108
                        Interval: 10 * time.Second,
×
1109
                }
×
1110
        )
×
1111

×
1112
        batch := make([][]byte, 0, cfg.MaxBatchSize)
×
1113

×
1114
        // validateBatch validates a batch of closed SCIDs using batch query.
×
1115
        validateBatch := func() error {
×
1116
                if len(batch) == 0 {
×
1117
                        return nil
×
1118
                }
×
1119

1120
                // Batch fetch all closed SCIDs from the database.
1121
                dbSCIDs, err := sqlDB.GetClosedChannelsSCIDs(ctx, batch)
×
1122
                if err != nil {
×
1123
                        return fmt.Errorf("could not batch get closed "+
×
1124
                                "SCIDs: %w", err)
×
1125
                }
×
1126

1127
                // Create set of SCIDs that exist in the database for quick
1128
                // lookup.
1129
                dbSCIDSet := make(map[string]struct{})
×
1130
                for _, scid := range dbSCIDs {
×
1131
                        dbSCIDSet[string(scid)] = struct{}{}
×
1132
                }
×
1133

1134
                // Validate each SCID in the batch.
1135
                for _, expectedSCID := range batch {
×
1136
                        if _, found := dbSCIDSet[string(expectedSCID)]; !found {
×
1137
                                return fmt.Errorf("closed SCID %x not found "+
×
1138
                                        "in database", expectedSCID)
×
1139
                        }
×
1140
                }
1141

1142
                // Reset the batch for the next iteration.
1143
                batch = make([][]byte, 0, cfg.MaxBatchSize)
×
1144

×
1145
                return nil
×
1146
        }
1147

1148
        migrateSingleClosedSCID := func(scid lnwire.ShortChannelID) error {
×
1149
                count++
×
1150
                chunk++
×
1151

×
1152
                chanIDB := channelIDToBytes(scid.ToUint64())
×
1153
                err := sqlDB.InsertClosedChannel(ctx, chanIDB)
×
1154
                if err != nil {
×
1155
                        return fmt.Errorf("could not insert closed channel "+
×
1156
                                "with SCID %s: %w", scid, err)
×
1157
                }
×
1158

1159
                // Add to validation batch.
1160
                batch = append(batch, chanIDB)
×
1161

×
1162
                // Validate batch when full.
×
1163
                if len(batch) >= int(cfg.MaxBatchSize) {
×
1164
                        err := validateBatch()
×
1165
                        if err != nil {
×
1166
                                return fmt.Errorf("batch validation failed: %w",
×
1167
                                        err)
×
1168
                        }
×
1169
                }
1170

1171
                s.Do(func() {
×
1172
                        elapsed := time.Since(t0).Seconds()
×
1173
                        ratePerSec := float64(chunk) / elapsed
×
1174
                        log.Debugf("Migrated %d closed scids "+
×
1175
                                "(%.2f entries/sec)", count, ratePerSec)
×
1176

×
1177
                        t0 = time.Now()
×
1178
                        chunk = 0
×
1179
                })
×
1180

1181
                return nil
×
1182
        }
1183

1184
        err := forEachClosedSCID(
×
1185
                kvBackend, migrateSingleClosedSCID, func() {
×
1186
                        count = 0
×
1187
                        chunk = 0
×
1188
                        t0 = time.Now()
×
1189
                        batch = make([][]byte, 0, cfg.MaxBatchSize)
×
1190
                },
×
1191
        )
1192
        if err != nil {
×
1193
                return fmt.Errorf("could not migrate closed SCID index: %w",
×
1194
                        err)
×
1195
        }
×
1196

1197
        // Validate any remaining SCIDs in the batch.
1198
        if len(batch) > 0 {
×
1199
                err := validateBatch()
×
1200
                if err != nil {
×
1201
                        return fmt.Errorf("final batch validation failed: %w",
×
1202
                                err)
×
1203
                }
×
1204
        }
1205

1206
        log.Infof("Migrated %d closed SCIDs from KV to SQL in %s", count,
×
1207
                time.Since(totalTime))
×
1208

×
1209
        return nil
×
1210
}

// migrateZombieIndex migrates the zombie index from the KV backend to the SQL
// database. It inserts zombie channels individually, collects them into
// batches, and validates each batch against the SQL store.
//
// NOTE: before inserting an entry into the zombie index, the function checks
// if the channel is already marked as closed in the SQL store. If it is,
// the entry is skipped. This means that the resulting zombie index count in
// the SQL store may well be less than the count of zombie channels in the KV
// store.
func migrateZombieIndex(ctx context.Context, cfg *sqldb.QueryConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        var (
                totalTime = time.Now()

                count uint64

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )

        type zombieEntry struct {
                pub1 route.Vertex
                pub2 route.Vertex
        }

        batch := make(map[uint64]*zombieEntry, cfg.MaxBatchSize)

        // validateBatch validates a batch of zombie SCIDs using a batch query.
        validateBatch := func() error {
                if len(batch) == 0 {
                        return nil
                }

                scids := make([][]byte, 0, len(batch))
                for scid := range batch {
                        scids = append(scids, channelIDToBytes(scid))
                }

                // Batch fetch all zombie channels from the database.
                rows, err := sqlDB.GetZombieChannelsSCIDs(
                        ctx, sqlc.GetZombieChannelsSCIDsParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
                if err != nil {
                        return fmt.Errorf("could not batch get zombie "+
                                "SCIDs: %w", err)
                }

                // Make sure that the number of rows returned matches
                // the number of SCIDs we requested.
                if len(rows) != len(scids) {
                        return fmt.Errorf("expected to fetch %d zombie "+
                                "SCIDs, but got %d", len(scids), len(rows))
                }

                // Validate each row is in the batch.
                for _, row := range rows {
                        scid := byteOrder.Uint64(row.Scid)

                        kvdbZombie, ok := batch[scid]
                        if !ok {
                                return fmt.Errorf("zombie SCID %x not found "+
                                        "in batch", scid)
                        }

                        err = sqldb.CompareRecords(
                                kvdbZombie.pub1[:], row.NodeKey1,
                                fmt.Sprintf("zombie pub key 1 (%s) for "+
                                        "channel %d", kvdbZombie.pub1, scid),
                        )
                        if err != nil {
                                return err
                        }

                        err = sqldb.CompareRecords(
                                kvdbZombie.pub2[:], row.NodeKey2,
                                fmt.Sprintf("zombie pub key 2 (%s) for "+
                                        "channel %d", kvdbZombie.pub2, scid),
                        )
                        if err != nil {
                                return err
                        }
                }

                // Reset the batch for the next iteration.
                batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize)

                return nil
        }

        err := forEachZombieEntry(kvBackend, func(chanID uint64, pubKey1,
                pubKey2 [33]byte) error {

                chanIDB := channelIDToBytes(chanID)

                // If it is in the closed SCID index, we don't need to
                // add it to the zombie index.
                //
                // NOTE: this means that the resulting zombie index count in
                // the SQL store may well be less than the count of zombie
                // channels in the KV store.
                isClosed, err := sqlDB.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("could not check closed "+
                                "channel: %w", err)
                }
                if isClosed {
                        return nil
                }

                count++
                chunk++

                err = sqlDB.UpsertZombieChannel(
                        ctx, sqlc.UpsertZombieChannelParams{
                                Version:  int16(lnwire.GossipVersion1),
                                Scid:     chanIDB,
                                NodeKey1: pubKey1[:],
                                NodeKey2: pubKey2[:],
                        },
                )
                if err != nil {
                        return fmt.Errorf("could not upsert zombie "+
                                "channel %d: %w", chanID, err)
                }

                // Add to validation batch only after successful insertion.
                batch[chanID] = &zombieEntry{
                        pub1: pubKey1,
                        pub2: pubKey2,
                }

                // Validate batch when full.
                if len(batch) >= int(cfg.MaxBatchSize) {
                        err := validateBatch()
                        if err != nil {
                                return fmt.Errorf("batch validation failed: %w",
                                        err)
                        }
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d zombie index entries "+
                                "(%.2f entries/sec)", count, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }, func() {
                count = 0
                chunk = 0
                t0 = time.Now()
                batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize)
        })
        if err != nil {
                return fmt.Errorf("could not migrate zombie index: %w", err)
        }

        // Validate any remaining zombie SCIDs in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        log.Infof("Migrated %d zombie channels from KV to SQL in %s", count,
                time.Since(totalTime))

        return nil
}
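
// Illustrative only: a minimal sketch of how this helper is expected to be
// wired into the top-level migration as the zombie-index step, assuming the
// caller holds a *sqldb.QueryConfig (queryCfg), the source kvdb.Backend and a
// migration-scoped SQLQueries transaction. The exact call site is not shown
// in this excerpt, so the names below are assumptions:
//
//      if err := migrateZombieIndex(ctx, queryCfg, kvBackend, sqlDB); err != nil {
//              return fmt.Errorf("could not migrate zombie index: %w", err)
//      }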

// forEachZombieEntry iterates over each zombie channel entry in the
// KV backend and calls the provided callback function for each entry.
func forEachZombieEntry(db kvdb.Backend, cb func(chanID uint64, pubKey1,
        pubKey2 [33]byte) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                edges := tx.ReadBucket(edgeBucket)
                if edges == nil {
                        return ErrGraphNoEdgesFound
                }
                zombieIndex := edges.NestedReadBucket(zombieBucket)
                if zombieIndex == nil {
                        return nil
                }

                return zombieIndex.ForEach(func(k, v []byte) error {
                        var pubKey1, pubKey2 [33]byte
                        copy(pubKey1[:], v[:33])
                        copy(pubKey2[:], v[33:])

                        return cb(byteOrder.Uint64(k), pubKey1, pubKey2)
                })
        }, reset)
}
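
// For reference, each zombie index entry decoded above is keyed by the 8-byte
// big-endian SCID and carries a 66-byte value holding the two 33-byte
// compressed node public keys back to back:
//
//      key:   k        -> byteOrder.Uint64(k) == channel ID (SCID)
//      value: v[:33]   -> pubKey1
//             v[33:]   -> pubKey2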

// forEachClosedSCID iterates over each closed SCID in the KV backend and calls
// the provided callback function for each SCID.
func forEachClosedSCID(db kvdb.Backend,
        cb func(lnwire.ShortChannelID) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                closedScids := tx.ReadBucket(closedScidBucket)
                if closedScids == nil {
                        return nil
                }

                return closedScids.ForEach(func(k, _ []byte) error {
                        return cb(lnwire.NewShortChanIDFromInt(
                                byteOrder.Uint64(k),
                        ))
                })
        }, reset)
}
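
// NOTE: the reset closures passed into forEachZombieEntry and
// forEachClosedSCID are handed straight through to kvdb.View. For backends
// that may retry the read transaction, the reset callback runs before each
// attempt, which is why the callers above zero their counters and re-allocate
// their batches inside it.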

// insertNodeSQLMig inserts the node record into the database during the graph
// SQL migration. No error is expected if the node already exists. Unlike the
// main upsertNode function, this function does not require that a new node
// update have a newer timestamp than the existing one. This is because we want
// the migration to be idempotent and don't want to error out if we re-insert
// the exact same node.
func insertNodeSQLMig(ctx context.Context, db SQLQueries,
        node *models.Node) (int64, error) {

        params := sqlc.InsertNodeMigParams{
                Version: int16(lnwire.GossipVersion1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveAnnouncement() {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStrValid(
                        EncodeHexColor(node.Color.UnwrapOr(color.RGBA{})),
                )
                params.Alias = sqldb.SQLStrValid(node.Alias.UnwrapOr(""))
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.InsertNodeMig(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveAnnouncement() {
                return nodeID, nil
        }

        // Insert the node's features.
        for feature := range node.Features.Features() {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: int32(feature),
                })
                if err != nil {
                        return 0, fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        // Insert the node's addresses.
        newAddresses, err := collectAddressRecords(node.Addresses)
        if err != nil {
                return 0, err
        }

        // All entries in newAddresses are addresses that need to be added to
        // the database for the first time.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.UpsertNodeAddress(
                                ctx, sqlc.UpsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return 0, fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Insert the node's extra signed fields.
        for tlvType, value := range extra {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return 0, fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nodeID, nil
}
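
// NOTE: features, addresses and extra signed TLV fields are only written when
// the node has a full announcement, since those fields all originate from the
// node_announcement message; a shell node only gets its protocol version and
// public key. The Position column written above records each address's index
// within its type-specific list, preserving the ordering returned by
// collectAddressRecords.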

// dbChanInfo holds the DB-level IDs of a channel and the nodes involved in the
// channel.
type dbChanInfo struct {
        channelID int64
        node1ID   int64
        node2ID   int64
}
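
// The IDs above are the surrogate row IDs assigned by the SQL store, not
// protocol-level identifiers; insertChanEdgePolicyMig uses them below to
// attach each policy to the right channel and node rows without having to
// look the channel up again by SCID.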

// insertChannelMig inserts a new channel record into the database during the
// graph SQL migration.
func insertChannelMig(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        //
        // NOTE: we need this even during the SQL migration where nodes are
        // migrated first because there are cases where some nodes may have
        // been skipped due to invalid TLV data.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.InsertChannelMigParams{
                Version:     int16(lnwire.GossipVersion1),
                Scid:        channelIDToBytes(edge.ChannelID),
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.InsertChannelMig(ctx, createParams)
        if err != nil {
                return nil, err
        }

        // Insert any channel features.
        for feature := range edge.Features.Features() {
                err = db.InsertChannelFeature(
                        ctx, sqlc.InsertChannelFeatureParams{
                                ChannelID:  dbChanID,
                                FeatureBit: int32(feature),
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
                                "feature(%v): %w", dbChanID, feature, err)
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return nil, fmt.Errorf("unable to marshal extra opaque "+
                        "data: %w", err)
        }

        for tlvType, value := range extra {
                err := db.UpsertChannelExtraType(
                        ctx, sqlc.UpsertChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to upsert "+
                                "channel(%d) extra signed field(%v): %w",
                                edge.ChannelID, tlvType, err)
                }
        }

        return &dbChanInfo{
                channelID: dbChanID,
                node1ID:   node1DBID,
                node2ID:   node2DBID,
        }, nil
}
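
// NOTE: two NULL-encoding conventions are worth calling out in the insert
// above: a zero capacity is stored as SQL NULL rather than 0, and the four
// announcement signatures are only populated when edge.AuthProof is present,
// so channels without an announcement proof keep NULL signature columns.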

// insertChanEdgePolicyMig inserts the channel policy info we have stored for
// a channel we already know of. This is used during the SQL migration
// process to insert channel policies.
func insertChanEdgePolicyMig(ctx context.Context, tx SQLQueries,
        dbChan *dbChanInfo, edge *models.ChannelEdgePolicy) error {

        // Figure out which node this edge is from.
        isNode1 := edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.node1ID
        if !isNode1 {
                nodeID = dbChan.node2ID
        }

        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.InsertEdgePolicyMig(ctx, sqlc.InsertEdgePolicyMigParams{
                Version:     int16(lnwire.GossipVersion1),
                ChannelID:   dbChan.channelID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Insert all new extra signed fields for the channel policy.
        for tlvType, value := range extra {
                err = tx.UpsertChanPolicyExtraType(
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
                                ChannelPolicyID: id,
                                Type:            int64(tlvType),
                                Value:           value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert "+
                                "channel_policy(%d) extra signed field(%v): %w",
                                id, tlvType, err)
                }
        }

        return nil
}
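
// NOTE: the direction check at the top of insertChanEdgePolicyMig follows the
// BOLT #7 channel_update convention: if the direction bit of channel_flags is
// unset, the update was produced by node_1, otherwise by node_2, which is why
// the policy row is attached to dbChan.node1ID or dbChan.node2ID accordingly.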