lightningnetwork / lnd / build 16986782393
15 Aug 2025 09:02AM UTC coverage: 66.662% (-0.1%) from 66.763%

Pull Request #10161: graph/db+sqldb: Make the SQL migration retry-safe/idempotent
Merge 5c41339ab into 365f1788e (github / web-flow)

0 of 336 new or added lines in 3 files covered. (0.0%)
87 existing lines in 20 files now uncovered.
135935 of 203917 relevant lines covered (66.66%)
21459.21 hits per line

Source File: /graph/db/sql_migration.go (coverage: 0.0%)
package graphdb

import (
        "bytes"
        "cmp"
        "context"
        "database/sql"
        "errors"
        "fmt"
        "net"
        "slices"
        "time"

        "github.com/btcsuite/btcd/chaincfg/chainhash"
        "github.com/lightningnetwork/lnd/graph/db/models"
        "github.com/lightningnetwork/lnd/kvdb"
        "github.com/lightningnetwork/lnd/lnwire"
        "github.com/lightningnetwork/lnd/routing/route"
        "github.com/lightningnetwork/lnd/sqldb"
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
        "golang.org/x/time/rate"
)

// MigrateGraphToSQL migrates the graph store from a KV backend to a SQL
// backend.
//
// NOTE: this is currently not called from any code path. It is called via tests
// only for now and will be called from the main lnd binary once the
// migration is fully implemented and tested.
func MigrateGraphToSQL(ctx context.Context, cfg *SQLStoreConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        log.Infof("Starting migration of the graph store from KV to SQL")
        t0 := time.Now()

        // Check if there is a graph to migrate.
        graphExists, err := checkGraphExists(kvBackend)
        if err != nil {
                return fmt.Errorf("failed to check graph existence: %w", err)
        }
        if !graphExists {
                log.Infof("No graph found in KV store, skipping the migration")
                return nil
        }

        // 1) Migrate all the nodes.
        err = migrateNodes(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate nodes: %w", err)
        }

        // 2) Migrate the source node.
        if err := migrateSourceNode(ctx, kvBackend, sqlDB); err != nil {
                return fmt.Errorf("could not migrate source node: %w", err)
        }

        // 3) Migrate all the channels and channel policies.
        err = migrateChannelsAndPolicies(ctx, cfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate channels and policies: %w",
                        err)
        }

        // 4) Migrate the Prune log.
        err = migratePruneLog(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate prune log: %w", err)
        }

        // 5) Migrate the closed SCID index.
        err = migrateClosedSCIDIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate closed SCID index: %w",
                        err)
        }

        // 6) Migrate the zombie index.
        err = migrateZombieIndex(ctx, cfg.QueryCfg, kvBackend, sqlDB)
        if err != nil {
                return fmt.Errorf("could not migrate zombie index: %w", err)
        }

        log.Infof("Finished migration of the graph store from KV to SQL in %v",
                time.Since(t0))

        return nil
}
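
// The pull request this file belongs to aims to make the SQL migration
// retry-safe/idempotent. One generic way to get retry-safety for a multi-step
// migration like MigrateGraphToSQL is to run every step inside a single
// database transaction and retry the whole transaction on failure, so a
// partially completed run is never left visible. The helper below is an
// illustrative sketch using only database/sql; it is not lnd's API and not
// necessarily how this PR implements it.
func runTxWithRetrySketch(ctx context.Context, db *sql.DB, maxAttempts int,
        migrate func(tx *sql.Tx) error) error {

        var lastErr error
        for attempt := 1; attempt <= maxAttempts; attempt++ {
                tx, err := db.BeginTx(ctx, nil)
                if err != nil {
                        return fmt.Errorf("begin tx: %w", err)
                }

                if err := migrate(tx); err != nil {
                        // Nothing was committed, so a later attempt starts
                        // from a clean slate.
                        _ = tx.Rollback()
                        lastErr = err

                        continue
                }

                if err := tx.Commit(); err != nil {
                        lastErr = err

                        continue
                }

                return nil
        }

        return fmt.Errorf("migration failed after %d attempts: %w",
                maxAttempts, lastErr)
}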

// checkGraphExists checks if the graph exists in the KV backend.
func checkGraphExists(db kvdb.Backend) (bool, error) {
        // Check if there is even a graph to migrate.
        err := db.View(func(tx kvdb.RTx) error {
                // Check for the existence of the node bucket which is a top
                // level bucket that would have been created on the initial
                // creation of the graph store.
                nodes := tx.ReadBucket(nodeBucket)
                if nodes == nil {
                        return ErrGraphNotFound
                }

                return nil
        }, func() {})
        if errors.Is(err, ErrGraphNotFound) {
                return false, nil
        } else if err != nil {
                return false, err
        }

        return true, nil
}

// migrateNodes migrates all nodes from the KV backend to the SQL database.
// It collects nodes in batches, inserts them individually, and then validates
// them in batches.
func migrateNodes(ctx context.Context, cfg *sqldb.QueryConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        // Keep track of the number of nodes migrated and the number of
        // nodes skipped due to errors.
        var (
                totalTime = time.Now()

                count   uint64
                skipped uint64

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )

        // batch is a map that holds node objects that have been migrated to
        // the native SQL store but have yet to be validated. The objects held
        // by this map were derived from the KVDB store and so when they are
        // validated, the map index (the SQL store node ID) will be used to
        // fetch the corresponding node object in the SQL store, and it will
        // then be compared against the original KVDB node object.
        batch := make(
                map[int64]*models.LightningNode, cfg.MaxBatchSize,
        )

        // validateBatch validates that the batch of nodes in the 'batch' map
        // have been migrated successfully.
        validateBatch := func() error {
                if len(batch) == 0 {
                        return nil
                }

                // Extract DB node IDs.
                dbIDs := make([]int64, 0, len(batch))
                for dbID := range batch {
                        dbIDs = append(dbIDs, dbID)
                }

                // Batch fetch all nodes from the database.
                dbNodes, err := sqlDB.GetNodesByIDs(ctx, dbIDs)
                if err != nil {
                        return fmt.Errorf("could not batch fetch nodes: %w",
                                err)
                }

                // Make sure that the number of nodes fetched matches the number
                // of nodes in the batch.
                if len(dbNodes) != len(batch) {
                        return fmt.Errorf("expected to fetch %d nodes, "+
                                "but got %d", len(batch), len(dbNodes))
                }

                // Now, batch fetch the normalised data for all the nodes in
                // the batch.
                batchData, err := batchLoadNodeData(ctx, cfg, sqlDB, dbIDs)
                if err != nil {
                        return fmt.Errorf("unable to batch load node data: %w",
                                err)
                }

                for _, dbNode := range dbNodes {
                        // Get the KVDB node info from the batch map.
                        node, ok := batch[dbNode.ID]
                        if !ok {
                                return fmt.Errorf("node with ID %d not found "+
                                        "in batch", dbNode.ID)
                        }

                        // Build the migrated node from the DB node and the
                        // batch node data.
                        migNode, err := buildNodeWithBatchData(
                                dbNode, batchData,
                        )
                        if err != nil {
                                return fmt.Errorf("could not build migrated "+
                                        "node from dbNode(db id: %d, node "+
                                        "pub: %x): %w", dbNode.ID,
                                        node.PubKeyBytes, err)
                        }

                        // Make sure that the node addresses are sorted before
                        // comparing them to ensure that the order of addresses
                        // does not affect the comparison.
                        slices.SortFunc(
                                node.Addresses, func(i, j net.Addr) int {
                                        return cmp.Compare(
                                                i.String(), j.String(),
                                        )
                                },
                        )
                        slices.SortFunc(
                                migNode.Addresses, func(i, j net.Addr) int {
                                        return cmp.Compare(
                                                i.String(), j.String(),
                                        )
                                },
                        )

                        err = sqldb.CompareRecords(
                                node, migNode,
                                fmt.Sprintf("node %x", node.PubKeyBytes),
                        )
                        if err != nil {
                                return fmt.Errorf("node mismatch after "+
                                        "migration for node %x: %w",
                                        node.PubKeyBytes, err)
                        }
                }

                // Clear the batch map for the next iteration.
                batch = make(
                        map[int64]*models.LightningNode, cfg.MaxBatchSize,
                )

                return nil
        }

        // Loop through each node in the KV store and insert it into the SQL
        // database.
        err := forEachNode(kvBackend, func(_ kvdb.RTx,
                node *models.LightningNode) error {

                pub := node.PubKeyBytes

                // Sanity check to ensure that the node has valid extra opaque
                // data. If it does not, we'll skip it. We need to do this
                // because previously we would just persist any TLV bytes that
                // we received without validating them. Now, however, we
                // normalise the storage of extra opaque data, so we need to
                // ensure that the data is valid. We don't want to abort the
                // migration if we encounter a node with invalid extra opaque
                // data, so we'll just skip it and log a warning.
                _, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
                if errors.Is(err, ErrParsingExtraTLVBytes) {
                        skipped++
                        log.Warnf("Skipping migration of node %x with invalid "+
                                "extra opaque data: %v", pub,
                                node.ExtraOpaqueData)

                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to marshal extra "+
                                "opaque data for node %x: %w", pub, err)
                }

                count++
                chunk++

                // TODO(elle): At this point, we should check the loaded node
                // to see if we should extract any DNS addresses from its
                // opaque type addresses. This is expected to be done in:
                // https://github.com/lightningnetwork/lnd/pull/9455.
                // This TODO is being tracked in
                // https://github.com/lightningnetwork/lnd/issues/9795 as this
                // must be addressed before making this code path active in
                // production.

                // Write the node to the SQL database.
                id, err := insertNodeSQLMig(ctx, sqlDB, node)
                if err != nil {
                        return fmt.Errorf("could not persist node(%x): %w", pub,
                                err)
                }

                // Add to validation batch.
                batch[id] = node

                // Validate batch when full.
                if len(batch) >= int(cfg.MaxBatchSize) {
                        err := validateBatch()
                        if err != nil {
                                return fmt.Errorf("batch validation failed: %w",
                                        err)
                        }
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d nodes (%.2f nodes/sec)",
                                count, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }, func() {
                count = 0
                chunk = 0
                skipped = 0
                t0 = time.Now()
                batch = make(map[int64]*models.LightningNode, cfg.MaxBatchSize)
        })
        if err != nil {
                return fmt.Errorf("could not migrate nodes: %w", err)
        }

        // Validate any remaining nodes in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        log.Infof("Migrated %d nodes from KV to SQL in %v (skipped %d nodes "+
                "due to invalid TLV streams)", count, time.Since(totalTime),
                skipped)

        return nil
}
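
// The throttled progress logging used in migrateNodes (and in the other
// migration loops below) relies on rate.Sometimes from golang.org/x/time/rate:
// Do runs its callback at most once per Interval, so a tight loop can report
// progress without flooding the logs. A minimal, self-contained sketch of the
// pattern (illustrative only, not part of the original file):
func throttledProgressSketch(process func(i int)) {
        s := rate.Sometimes{Interval: 10 * time.Second}

        for i := 0; i < 1_000_000; i++ {
                process(i)

                // This callback fires at most once every 10 seconds, no
                // matter how fast the loop spins.
                s.Do(func() {
                        log.Debugf("processed %d items so far", i)
                })
        }
}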

// migrateSourceNode migrates the source node from the KV backend to the
// SQL database.
func migrateSourceNode(ctx context.Context, kvdb kvdb.Backend,
        sqlDB SQLQueries) error {

        log.Debugf("Migrating source node from KV to SQL")

        sourceNode, err := sourceNode(kvdb)
        if errors.Is(err, ErrSourceNodeNotSet) {
                // If the source node has not been set yet, we can skip this
                // migration step.
                return nil
        } else if err != nil {
                return fmt.Errorf("could not get source node from kv "+
                        "store: %w", err)
        }

        pub := sourceNode.PubKeyBytes

        // Get the DB ID of the source node by its public key. This node must
        // already exist in the SQL database, as it should have been migrated
        // in the previous node-migration step.
        id, err := sqlDB.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        PubKey:  pub[:],
                        Version: int16(ProtocolV1),
                },
        )
        if err != nil {
                return fmt.Errorf("could not get source node ID: %w", err)
        }

        // Now we can add the source node to the SQL database.
        err = sqlDB.AddSourceNode(ctx, id)
        if err != nil {
                return fmt.Errorf("could not add source node to SQL store: %w",
                        err)
        }

        // Verify that the source node was added correctly by fetching it back
        // from the SQL database and checking that the expected DB ID and
        // pub key are returned. We don't need to do a whole node comparison
        // here, as this was already done in the previous migration step.
        srcNodes, err := sqlDB.GetSourceNodesByVersion(ctx, int16(ProtocolV1))
        if err != nil {
                return fmt.Errorf("could not get source nodes from SQL "+
                        "store: %w", err)
        }

        // The SQL store has support for multiple source nodes (for future
        // protocol versions) but this migration is purely aimed at the V1
        // store, and so we expect exactly one source node to be present.
        if len(srcNodes) != 1 {
                return fmt.Errorf("expected exactly one source node, "+
                        "got %d", len(srcNodes))
        }

        // Check that the source node ID and pub key match the original
        // source node.
        if srcNodes[0].NodeID != id {
                return fmt.Errorf("source node ID mismatch after migration: "+
                        "expected %d, got %d", id, srcNodes[0].NodeID)
        }
        err = sqldb.CompareRecords(pub[:], srcNodes[0].PubKey, "source node")
        if err != nil {
                return fmt.Errorf("source node pubkey mismatch after "+
                        "migration: %w", err)
        }

        log.Infof("Migrated source node with pubkey %x to SQL", pub[:])

        return nil
}

// migChanInfo holds the information about a channel and its policies.
type migChanInfo struct {
        // edge is the channel object as read from the KVDB source.
        edge *models.ChannelEdgeInfo

        // policy1 is the first channel policy for the channel as read from
        // the KVDB source.
        policy1 *models.ChannelEdgePolicy

        // policy2 is the second channel policy for the channel as read
        // from the KVDB source.
        policy2 *models.ChannelEdgePolicy

        // dbInfo holds location info (in the form of DB IDs) of the channel
        // and its policies in the native-SQL destination.
        dbInfo *dbChanInfo
}

// migrateChannelsAndPolicies migrates all channels and their policies
// from the KV backend to the SQL database.
func migrateChannelsAndPolicies(ctx context.Context, cfg *SQLStoreConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        var (
                totalTime = time.Now()

                channelCount       uint64
                skippedChanCount   uint64
                policyCount        uint64
                skippedPolicyCount uint64

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )
        migChanPolicy := func(dbChanInfo *dbChanInfo,
                policy *models.ChannelEdgePolicy) error {

                // If the policy is nil, we can skip it.
                if policy == nil {
                        return nil
                }

                // Unlike the special case of invalid TLV bytes for node and
                // channel announcements, we don't need to handle the case for
                // channel policies here because it is already handled in the
                // `forEachChannel` function. If the policy has invalid TLV
                // bytes, then `nil` will be passed to this function.

                policyCount++

                err := insertChanEdgePolicyMig(ctx, sqlDB, dbChanInfo, policy)
                if err != nil {
                        return fmt.Errorf("could not migrate channel "+
                                "policy %d: %w", policy.ChannelID, err)
                }

                return nil
        }

        // batch is used to collect migrated channel info that we will
        // batch-validate. Each entry is indexed by the DB ID of the channel
        // in the SQL database.
        batch := make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)

        // Iterate over each channel in the KV store and migrate it and its
        // policies to the SQL database.
        err := forEachChannel(kvBackend, func(channel *models.ChannelEdgeInfo,
                policy1 *models.ChannelEdgePolicy,
                policy2 *models.ChannelEdgePolicy) error {

                scid := channel.ChannelID

                // Here, we do a sanity check to ensure that the chain hash of
                // the channel returned by the KV store matches the expected
                // chain hash. This is important since in the SQL store, we will
                // no longer explicitly store the chain hash in the channel
                // info, but rather rely on the chain hash LND is running with.
                // So this is our way of ensuring that LND is running on the
                // correct network at migration time.
                if channel.ChainHash != cfg.ChainHash {
                        return fmt.Errorf("channel %d has chain hash %s, "+
                                "expected %s", scid, channel.ChainHash,
                                cfg.ChainHash)
                }

                // Sanity check to ensure that the channel has valid extra
                // opaque data. If it does not, we'll skip it. We need to do
                // this because previously we would just persist any TLV bytes
                // that we received without validating them. Now, however, we
                // normalise the storage of extra opaque data, so we need to
                // ensure that the data is valid. We don't want to abort the
                // migration if we encounter a channel with invalid extra opaque
                // data, so we'll just skip it and log a warning.
                _, err := marshalExtraOpaqueData(channel.ExtraOpaqueData)
                if errors.Is(err, ErrParsingExtraTLVBytes) {
                        log.Warnf("Skipping channel %d with invalid "+
                                "extra opaque data: %v", scid,
                                channel.ExtraOpaqueData)

                        skippedChanCount++

                        // If we skip a channel, we also skip its policies.
                        if policy1 != nil {
                                skippedPolicyCount++
                        }
                        if policy2 != nil {
                                skippedPolicyCount++
                        }

                        return nil
                } else if err != nil {
                        return fmt.Errorf("unable to marshal extra opaque "+
                                "data for channel %d (%v): %w", scid,
                                channel.ExtraOpaqueData, err)
                }

                channelCount++
                chunk++

                // Migrate the channel info along with its policies.
                dbChanInfo, err := insertChannelMig(ctx, sqlDB, channel)
                if err != nil {
                        return fmt.Errorf("could not insert record for "+
                                "channel %d in SQL store: %w", scid, err)
                }

                // Now, migrate the two channel policies for the channel.
                err = migChanPolicy(dbChanInfo, policy1)
                if err != nil {
                        return fmt.Errorf("could not migrate policy1(%d): %w",
                                scid, err)
                }
                err = migChanPolicy(dbChanInfo, policy2)
                if err != nil {
                        return fmt.Errorf("could not migrate policy2(%d): %w",
                                scid, err)
                }

                // Collect the migrated channel info and policies in a batch for
                // later validation.
                batch[dbChanInfo.channelID] = &migChanInfo{
                        edge:    channel,
                        policy1: policy1,
                        policy2: policy2,
                        dbInfo:  dbChanInfo,
                }

                if len(batch) >= int(cfg.QueryCfg.MaxBatchSize) {
                        // Do batch validation.
                        err := validateMigratedChannels(ctx, cfg, sqlDB, batch)
                        if err != nil {
                                return fmt.Errorf("could not validate "+
                                        "channel batch: %w", err)
                        }

                        batch = make(
                                map[int64]*migChanInfo,
                                cfg.QueryCfg.MaxBatchSize,
                        )
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d channels (%.2f channels/sec)",
                                channelCount, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }, func() {
                channelCount = 0
                policyCount = 0
                chunk = 0
                skippedChanCount = 0
                skippedPolicyCount = 0
                t0 = time.Now()
                batch = make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)
        })
        if err != nil {
                return fmt.Errorf("could not migrate channels and policies: %w",
                        err)
        }

        if len(batch) > 0 {
                // Do a final batch validation for any remaining channels.
                err := validateMigratedChannels(ctx, cfg, sqlDB, batch)
                if err != nil {
                        return fmt.Errorf("could not validate final channel "+
                                "batch: %w", err)
                }

                batch = make(map[int64]*migChanInfo, cfg.QueryCfg.MaxBatchSize)
        }

        log.Infof("Migrated %d channels and %d policies from KV to SQL in %s "+
                "(skipped %d channels and %d policies due to invalid TLV "+
                "streams)", channelCount, policyCount, time.Since(totalTime),
                skippedChanCount, skippedPolicyCount)

        return nil
}

// validateMigratedChannels validates the channels in the batch after they have
// been migrated to the SQL database. It batch fetches all channels by their IDs
// and compares the migrated channels and their policies with the original ones
// to ensure they match, using batch construction patterns.
func validateMigratedChannels(ctx context.Context, cfg *SQLStoreConfig,
        sqlDB SQLQueries, batch map[int64]*migChanInfo) error {

        // Convert batch keys (DB IDs) to an int slice for the batch query.
        dbChanIDs := make([]int64, 0, len(batch))
        for id := range batch {
                dbChanIDs = append(dbChanIDs, id)
        }

        // Batch fetch all channels with their policies.
        rows, err := sqlDB.GetChannelsByIDs(ctx, dbChanIDs)
        if err != nil {
                return fmt.Errorf("could not batch get channels by IDs: %w",
                        err)
        }

        // Sanity check that the same number of channels were returned
        // as requested.
        if len(rows) != len(dbChanIDs) {
                return fmt.Errorf("expected to fetch %d channels, "+
                        "but got %d", len(dbChanIDs), len(rows))
        }

        // Collect all policy IDs needed for batch data loading.
        dbPolicyIDs := make([]int64, 0, len(dbChanIDs)*2)

        for _, row := range rows {
                scid := byteOrder.Uint64(row.GraphChannel.Scid)

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("could not extract channel policies"+
                                " for SCID %d: %w", scid, err)
                }
                if dbPol1 != nil {
                        dbPolicyIDs = append(dbPolicyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        dbPolicyIDs = append(dbPolicyIDs, dbPol2.ID)
                }
        }

        // Batch load all channel and policy data (features, extras).
        batchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, sqlDB, dbChanIDs, dbPolicyIDs,
        )
        if err != nil {
                return fmt.Errorf("could not batch load channel and policy "+
                        "data: %w", err)
        }

        // Validate each channel in the batch using pre-loaded data.
        for _, row := range rows {
                kvdbChan, ok := batch[row.GraphChannel.ID]
                if !ok {
                        return fmt.Errorf("channel with ID %d not found "+
                                "in batch", row.GraphChannel.ID)
                }

                scid := byteOrder.Uint64(row.GraphChannel.Scid)

                err = validateMigratedChannelWithBatchData(
                        cfg, scid, kvdbChan, row, batchData,
                )
                if err != nil {
                        return fmt.Errorf("channel %d validation failed "+
                                "after migration: %w", scid, err)
                }
        }

        return nil
}

// validateMigratedChannelWithBatchData validates a single migrated channel
// using pre-fetched batch data for optimal performance.
func validateMigratedChannelWithBatchData(cfg *SQLStoreConfig,
        scid uint64, info *migChanInfo, row sqlc.GetChannelsByIDsRow,
        batchData *batchChannelData) error {

        dbChanInfo := info.dbInfo
        channel := info.edge

        // Assert that the DB IDs for the channel and nodes are as expected
        // given the inserted channel info.
        err := sqldb.CompareRecords(
                dbChanInfo.channelID, row.GraphChannel.ID, "channel DB ID",
        )
        if err != nil {
                return err
        }
        err = sqldb.CompareRecords(
                dbChanInfo.node1ID, row.Node1ID, "node1 DB ID",
        )
        if err != nil {
                return err
        }
        err = sqldb.CompareRecords(
                dbChanInfo.node2ID, row.Node2ID, "node2 DB ID",
        )
        if err != nil {
                return err
        }

        // Build node vertices from the row data.
        node1, node2, err := buildNodeVertices(
                row.Node1PubKey, row.Node2PubKey,
        )
        if err != nil {
                return err
        }

        // Build channel info using batch data.
        migChan, err := buildEdgeInfoWithBatchData(
                cfg.ChainHash, row.GraphChannel, node1, node2, batchData,
        )
        if err != nil {
                return fmt.Errorf("could not build migrated channel info: %w",
                        err)
        }

        // Extract channel policies from the row.
        dbPol1, dbPol2, err := extractChannelPolicies(row)
        if err != nil {
                return fmt.Errorf("could not extract channel policies: %w", err)
        }

        // Build channel policies using batch data.
        migPol1, migPol2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, scid, node1, node2, batchData,
        )
        if err != nil {
                return fmt.Errorf("could not build migrated channel "+
                        "policies: %w", err)
        }

        // Finally, compare the original channel info and
        // policies with the migrated ones to ensure they match.
        if len(channel.ExtraOpaqueData) == 0 {
                channel.ExtraOpaqueData = nil
        }
        if len(migChan.ExtraOpaqueData) == 0 {
                migChan.ExtraOpaqueData = nil
        }

        err = sqldb.CompareRecords(
                channel, migChan, fmt.Sprintf("channel %d", scid),
        )
        if err != nil {
                return err
        }

        checkPolicy := func(expPolicy,
                migPolicy *models.ChannelEdgePolicy) error {

                switch {
                // Both policies are nil, nothing to compare.
                case expPolicy == nil && migPolicy == nil:
                        return nil

                // One of the policies is nil, but the other is not.
                case expPolicy == nil || migPolicy == nil:
                        return fmt.Errorf("expected both policies to be "+
                                "non-nil. Got expPolicy: %v, "+
                                "migPolicy: %v", expPolicy, migPolicy)

                // Both policies are non-nil, we can compare them.
                default:
                }

                if len(expPolicy.ExtraOpaqueData) == 0 {
                        expPolicy.ExtraOpaqueData = nil
                }
                if len(migPolicy.ExtraOpaqueData) == 0 {
                        migPolicy.ExtraOpaqueData = nil
                }

                return sqldb.CompareRecords(
                        *expPolicy, *migPolicy, "channel policy",
                )
        }

        err = checkPolicy(info.policy1, migPol1)
        if err != nil {
                return fmt.Errorf("policy1 mismatch for channel %d: %w", scid,
                        err)
        }

        err = checkPolicy(info.policy2, migPol2)
        if err != nil {
                return fmt.Errorf("policy2 mismatch for channel %d: %w", scid,
                        err)
        }

        return nil
}

// migratePruneLog migrates the prune log from the KV backend to the SQL
// database. It collects entries in batches, inserts them individually, and then
// validates them in batches using GetPruneEntriesForHeights for better
// performance.
func migratePruneLog(ctx context.Context, cfg *sqldb.QueryConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        var (
                totalTime = time.Now()

                count          uint64
                pruneTipHeight uint32
                pruneTipHash   chainhash.Hash

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )

        batch := make(map[uint32]chainhash.Hash, cfg.MaxBatchSize)

        // validateBatch validates a batch of prune entries using batch query.
        validateBatch := func() error {
                if len(batch) == 0 {
                        return nil
                }

                // Extract heights for the batch query.
                heights := make([]int64, 0, len(batch))
                for height := range batch {
                        heights = append(heights, int64(height))
                }

                // Batch fetch all entries from the database.
                rows, err := sqlDB.GetPruneEntriesForHeights(ctx, heights)
                if err != nil {
                        return fmt.Errorf("could not batch get prune "+
                                "entries: %w", err)
                }

                if len(rows) != len(batch) {
                        return fmt.Errorf("expected to fetch %d prune "+
                                "entries, but got %d", len(batch),
                                len(rows))
                }

                // Validate each entry in the batch.
                for _, row := range rows {
                        kvdbHash, ok := batch[uint32(row.BlockHeight)]
                        if !ok {
                                return fmt.Errorf("prune entry for height %d "+
                                        "not found in batch", row.BlockHeight)
                        }

                        err := sqldb.CompareRecords(
                                kvdbHash[:], row.BlockHash,
                                fmt.Sprintf("prune log entry at height %d",
                                        row.BlockHeight),
                        )
                        if err != nil {
                                return err
                        }
                }

                // Reset the batch map for the next iteration.
                batch = make(map[uint32]chainhash.Hash, cfg.MaxBatchSize)

                return nil
        }

        // Iterate over each prune log entry in the KV store and migrate it to
        // the SQL database.
        err := forEachPruneLogEntry(
                kvBackend, func(height uint32, hash *chainhash.Hash) error {
                        count++
                        chunk++

                        // Keep track of the prune tip height and hash.
                        if height > pruneTipHeight {
                                pruneTipHeight = height
                                pruneTipHash = *hash
                        }

                        // Insert the entry (individual inserts for now).
                        err := sqlDB.UpsertPruneLogEntry(
                                ctx, sqlc.UpsertPruneLogEntryParams{
                                        BlockHeight: int64(height),
                                        BlockHash:   hash[:],
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert prune log "+
                                        "entry for height %d: %w", height, err)
                        }

                        // Add to validation batch.
                        batch[height] = *hash

                        // Validate batch when full.
                        if len(batch) >= int(cfg.MaxBatchSize) {
                                err := validateBatch()
                                if err != nil {
                                        return fmt.Errorf("batch "+
                                                "validation failed: %w", err)
                                }
                        }

                        s.Do(func() {
                                elapsed := time.Since(t0).Seconds()
                                ratePerSec := float64(chunk) / elapsed
                                log.Debugf("Migrated %d prune log "+
                                        "entries (%.2f entries/sec)",
                                        count, ratePerSec)

                                t0 = time.Now()
                                chunk = 0
                        })

                        return nil
                },
                func() {
                        count = 0
                        chunk = 0
                        t0 = time.Now()
                        batch = make(
                                map[uint32]chainhash.Hash, cfg.MaxBatchSize,
                        )
                },
        )
        if err != nil {
                return fmt.Errorf("could not migrate prune log: %w", err)
        }

        // Validate any remaining entries in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        // Check that the prune tip is set correctly in the SQL
        // database.
        pruneTip, err := sqlDB.GetPruneTip(ctx)
        if errors.Is(err, sql.ErrNoRows) {
                // The ErrGraphNeverPruned error is expected if no prune log
                // entries were migrated from the kvdb store. Otherwise, it's
                // an unexpected error.
                if count == 0 {
                        log.Infof("No prune log entries found in KV store " +
                                "to migrate")
                        return nil
                }
                // Fall-through to the next error check.
        }
        if err != nil {
                return fmt.Errorf("could not get prune tip: %w", err)
        }

        if pruneTip.BlockHeight != int64(pruneTipHeight) ||
                !bytes.Equal(pruneTip.BlockHash, pruneTipHash[:]) {

                return fmt.Errorf("prune tip mismatch after migration: "+
                        "expected height %d, hash %s; got height %d, "+
                        "hash %s", pruneTipHeight, pruneTipHash,
                        pruneTip.BlockHeight,
                        chainhash.Hash(pruneTip.BlockHash))
        }

        log.Infof("Migrated %d prune log entries from KV to SQL in %s. "+
                "The prune tip is: height %d, hash: %s", count,
                time.Since(totalTime), pruneTipHeight, pruneTipHash)

        return nil
}
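
// migratePruneLog relies on upsert semantics (UpsertPruneLogEntry) so that
// re-running the step after a partial failure simply overwrites rows instead
// of failing on duplicates, which is what makes it idempotent. Expressed
// directly against database/sql with Postgres-style placeholders and a
// hypothetical table name, the idea looks roughly like this (illustrative
// sketch only, not lnd's sqlc-generated API):
func upsertPruneEntrySketch(ctx context.Context, db *sql.DB, height int64,
        hash []byte) error {

        // ON CONFLICT turns a plain INSERT into an idempotent write: a second
        // run of the same migration step updates the row in place.
        _, err := db.ExecContext(ctx, `
                INSERT INTO prune_log (block_height, block_hash)
                VALUES ($1, $2)
                ON CONFLICT (block_height)
                DO UPDATE SET block_hash = EXCLUDED.block_hash`,
                height, hash,
        )
        if err != nil {
                return fmt.Errorf("upsert prune entry: %w", err)
        }

        return nil
}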

// forEachPruneLogEntry iterates over each prune log entry in the KV
// backend and calls the provided callback function for each entry.
func forEachPruneLogEntry(db kvdb.Backend, cb func(height uint32,
        hash *chainhash.Hash) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                metaBucket := tx.ReadBucket(graphMetaBucket)
                if metaBucket == nil {
                        return ErrGraphNotFound
                }

                pruneBucket := metaBucket.NestedReadBucket(pruneLogBucket)
                if pruneBucket == nil {
                        // The graph has never been pruned and so, there are no
                        // entries to iterate over.
                        return nil
                }

                return pruneBucket.ForEach(func(k, v []byte) error {
                        blockHeight := byteOrder.Uint32(k)
                        var blockHash chainhash.Hash
                        copy(blockHash[:], v)

                        return cb(blockHeight, &blockHash)
                })
        }, reset)
}

// migrateClosedSCIDIndex migrates the closed SCID index from the KV backend to
// the SQL database. It collects SCIDs in batches, inserts them individually,
// and then validates them in batches using GetClosedChannelsSCIDs for better
// performance.
func migrateClosedSCIDIndex(ctx context.Context, cfg *sqldb.QueryConfig,
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {

        var (
                totalTime = time.Now()

                count uint64

                t0    = time.Now()
                chunk uint64
                s     = rate.Sometimes{
                        Interval: 10 * time.Second,
                }
        )

        batch := make([][]byte, 0, cfg.MaxBatchSize)

        // validateBatch validates a batch of closed SCIDs using batch query.
        validateBatch := func() error {
                if len(batch) == 0 {
                        return nil
                }

                // Batch fetch all closed SCIDs from the database.
                dbSCIDs, err := sqlDB.GetClosedChannelsSCIDs(ctx, batch)
                if err != nil {
                        return fmt.Errorf("could not batch get closed "+
                                "SCIDs: %w", err)
                }

                // Create set of SCIDs that exist in the database for quick
                // lookup.
                dbSCIDSet := make(map[string]struct{})
                for _, scid := range dbSCIDs {
                        dbSCIDSet[string(scid)] = struct{}{}
                }

                // Validate each SCID in the batch.
                for _, expectedSCID := range batch {
                        if _, found := dbSCIDSet[string(expectedSCID)]; !found {
                                return fmt.Errorf("closed SCID %x not found "+
                                        "in database", expectedSCID)
                        }
                }

                // Reset the batch for the next iteration.
                batch = make([][]byte, 0, cfg.MaxBatchSize)

                return nil
        }

        migrateSingleClosedSCID := func(scid lnwire.ShortChannelID) error {
                count++
                chunk++

                chanIDB := channelIDToBytes(scid.ToUint64())
                err := sqlDB.InsertClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("could not insert closed channel "+
                                "with SCID %s: %w", scid, err)
                }

                // Add to validation batch.
                batch = append(batch, chanIDB)

                // Validate batch when full.
                if len(batch) >= int(cfg.MaxBatchSize) {
                        err := validateBatch()
                        if err != nil {
                                return fmt.Errorf("batch validation failed: %w",
                                        err)
                        }
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d closed scids "+
                                "(%.2f entries/sec)", count, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }

        err := forEachClosedSCID(
                kvBackend, migrateSingleClosedSCID, func() {
                        count = 0
                        chunk = 0
                        t0 = time.Now()
                        batch = make([][]byte, 0, cfg.MaxBatchSize)
                },
        )
        if err != nil {
                return fmt.Errorf("could not migrate closed SCID index: %w",
                        err)
        }

        // Validate any remaining SCIDs in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        log.Infof("Migrated %d closed SCIDs from KV to SQL in %s", count,
                time.Since(totalTime))

        return nil
}
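
// The closed-SCID migration above and the zombie-index migration below key
// their validation batches on the raw SCID bytes produced by the package's
// channelIDToBytes helper. The encoding is simply the uint64 short channel ID
// serialised with the package's byte order; an equivalent, illustrative
// helper (sketch only, not part of the original file) would be:
func scidKeySketch(scid uint64) []byte {
        var key [8]byte
        byteOrder.PutUint64(key[:], scid)

        return key[:]
}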
1139

1140
// migrateZombieIndex migrates the zombie index from the KV backend to the SQL
1141
// database. It collects zombie channels in batches, inserts them individually,
1142
// and validates them in batches.
1143
//
1144
// NOTE: before inserting an entry into the zombie index, the function checks
1145
// if the channel is already marked as closed in the SQL store. If it is,
1146
// the entry is skipped. This means that the resulting zombie index count in
1147
// the SQL store may well be less than the count of zombie channels in the KV
1148
// store.
1149
func migrateZombieIndex(ctx context.Context, cfg *sqldb.QueryConfig,
1150
        kvBackend kvdb.Backend, sqlDB SQLQueries) error {
×
1151

×
1152
        var (
×
1153
                totalTime = time.Now()
×
1154

×
1155
                count uint64
×
1156

×
1157
                t0    = time.Now()
×
1158
                chunk uint64
×
1159
                s     = rate.Sometimes{
×
1160
                        Interval: 10 * time.Second,
×
1161
                }
×
1162
        )
×
1163

×
1164
        type zombieEntry struct {
×
1165
                pub1 route.Vertex
×
1166
                pub2 route.Vertex
×
1167
        }
×
1168

×
1169
        batch := make(map[uint64]*zombieEntry, cfg.MaxBatchSize)
×
1170

×
1171
        // validateBatch validates a batch of zombie SCIDs using batch query.
×
1172
        validateBatch := func() error {
×
1173
                if len(batch) == 0 {
×
1174
                        return nil
×
1175
                }
×
1176

1177
                scids := make([][]byte, 0, len(batch))
×
1178
                for scid := range batch {
×
1179
                        scids = append(scids, channelIDToBytes(scid))
                }

                // Batch fetch all zombie channels from the database.
                rows, err := sqlDB.GetZombieChannelsSCIDs(
                        ctx, sqlc.GetZombieChannelsSCIDsParams{
                                Version: int16(ProtocolV1),
                                Scids:   scids,
                        },
                )
                if err != nil {
                        return fmt.Errorf("could not batch get zombie "+
                                "SCIDs: %w", err)
                }

                // Make sure that the number of rows returned matches
                // the number of SCIDs we requested.
                if len(rows) != len(scids) {
                        return fmt.Errorf("expected to fetch %d zombie "+
                                "SCIDs, but got %d", len(scids), len(rows))
                }

                // Validate each row is in the batch.
                for _, row := range rows {
                        scid := byteOrder.Uint64(row.Scid)

                        kvdbZombie, ok := batch[scid]
                        if !ok {
                                return fmt.Errorf("zombie SCID %x not found "+
                                        "in batch", scid)
                        }

                        err = sqldb.CompareRecords(
                                kvdbZombie.pub1[:], row.NodeKey1,
                                fmt.Sprintf("zombie pub key 1 (%s) for "+
                                        "channel %d", kvdbZombie.pub1, scid),
                        )
                        if err != nil {
                                return err
                        }

                        err = sqldb.CompareRecords(
                                kvdbZombie.pub2[:], row.NodeKey2,
                                fmt.Sprintf("zombie pub key 2 (%s) for "+
                                        "channel %d", kvdbZombie.pub2, scid),
                        )
                        if err != nil {
                                return err
                        }
                }

                // Reset the batch for the next iteration.
                batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize)

                return nil
        }

        err := forEachZombieEntry(kvBackend, func(chanID uint64, pubKey1,
                pubKey2 [33]byte) error {

                chanIDB := channelIDToBytes(chanID)

                // If it is in the closed SCID index, we don't need to
                // add it to the zombie index.
                //
                // NOTE: this means that the resulting zombie index count in
                // the SQL store may well be less than the count of zombie
                // channels in the KV store.
                isClosed, err := sqlDB.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("could not check closed "+
                                "channel: %w", err)
                }
                if isClosed {
                        return nil
                }

                count++
                chunk++

                err = sqlDB.UpsertZombieChannel(
                        ctx, sqlc.UpsertZombieChannelParams{
                                Version:  int16(ProtocolV1),
                                Scid:     chanIDB,
                                NodeKey1: pubKey1[:],
                                NodeKey2: pubKey2[:],
                        },
                )
                if err != nil {
                        return fmt.Errorf("could not upsert zombie "+
                                "channel %d: %w", chanID, err)
                }

                // Add to validation batch only after successful insertion.
                batch[chanID] = &zombieEntry{
                        pub1: pubKey1,
                        pub2: pubKey2,
                }

                // Validate batch when full.
                if len(batch) >= int(cfg.MaxBatchSize) {
                        err := validateBatch()
                        if err != nil {
                                return fmt.Errorf("batch validation failed: %w",
                                        err)
                        }
                }

                s.Do(func() {
                        elapsed := time.Since(t0).Seconds()
                        ratePerSec := float64(chunk) / elapsed
                        log.Debugf("Migrated %d zombie index entries "+
                                "(%.2f entries/sec)", count, ratePerSec)

                        t0 = time.Now()
                        chunk = 0
                })

                return nil
        }, func() {
                count = 0
                chunk = 0
                t0 = time.Now()
                batch = make(map[uint64]*zombieEntry, cfg.MaxBatchSize)
        })
        if err != nil {
                return fmt.Errorf("could not migrate zombie index: %w", err)
        }

        // Validate any remaining zombie SCIDs in the batch.
        if len(batch) > 0 {
                err := validateBatch()
                if err != nil {
                        return fmt.Errorf("final batch validation failed: %w",
                                err)
                }
        }

        log.Infof("Migrated %d zombie channels from KV to SQL in %s", count,
                time.Since(totalTime))

        return nil
}
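
// logThrottledSketch is an illustrative sketch (not called from any code
// path) of the throttled progress logging pattern used by the migration
// functions above. It assumes, based on the s.Do calls and the rate import,
// that the limiter is a rate.Sometimes value from golang.org/x/time/rate
// declared elsewhere in this file.
func logThrottledSketch(limiter *rate.Sometimes, count, chunk int64,
        t0 time.Time) {

        // Only emit a log line as often as the limiter allows, reporting the
        // running total and the rate observed for the current chunk.
        limiter.Do(func() {
                elapsed := time.Since(t0).Seconds()
                log.Debugf("Migrated %d entries (%.2f entries/sec)",
                        count, float64(chunk)/elapsed)
        })
}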

// forEachZombieEntry iterates over each zombie channel entry in the
// KV backend and calls the provided callback function for each entry.
func forEachZombieEntry(db kvdb.Backend, cb func(chanID uint64, pubKey1,
        pubKey2 [33]byte) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                edges := tx.ReadBucket(edgeBucket)
                if edges == nil {
                        return ErrGraphNoEdgesFound
                }
                zombieIndex := edges.NestedReadBucket(zombieBucket)
                if zombieIndex == nil {
                        return nil
                }

                return zombieIndex.ForEach(func(k, v []byte) error {
                        var pubKey1, pubKey2 [33]byte
                        copy(pubKey1[:], v[:33])
                        copy(pubKey2[:], v[33:])

                        return cb(byteOrder.Uint64(k), pubKey1, pubKey2)
                })
        }, reset)
}
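
// countZombieEntriesSketch is an illustrative usage sketch for
// forEachZombieEntry and is not called from any code path: it tallies the
// zombie channels recorded in the KV backend, using the reset callback to
// zero the counter in case the underlying kvdb view is retried.
func countZombieEntriesSketch(db kvdb.Backend) (uint64, error) {
        var count uint64
        err := forEachZombieEntry(db, func(_ uint64, _, _ [33]byte) error {
                count++
                return nil
        }, func() {
                count = 0
        })

        return count, err
}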

// forEachClosedSCID iterates over each closed SCID in the KV backend and calls
// the provided callback function for each SCID.
func forEachClosedSCID(db kvdb.Backend,
        cb func(lnwire.ShortChannelID) error, reset func()) error {

        return kvdb.View(db, func(tx kvdb.RTx) error {
                closedScids := tx.ReadBucket(closedScidBucket)
                if closedScids == nil {
                        return nil
                }

                return closedScids.ForEach(func(k, _ []byte) error {
                        return cb(lnwire.NewShortChanIDFromInt(
                                byteOrder.Uint64(k),
                        ))
                })
        }, reset)
}
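
// collectClosedSCIDsSketch is an illustrative usage sketch for
// forEachClosedSCID and is not called from any code path: it gathers every
// closed SCID into a slice, clearing the slice in the reset callback so that
// a retried kvdb view does not produce duplicates.
func collectClosedSCIDsSketch(db kvdb.Backend) ([]lnwire.ShortChannelID,
        error) {

        var scids []lnwire.ShortChannelID
        err := forEachClosedSCID(db, func(scid lnwire.ShortChannelID) error {
                scids = append(scids, scid)
                return nil
        }, func() {
                scids = nil
        })

        return scids, err
}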

// insertNodeSQLMig inserts the node record into the database during the graph
// SQL migration. No error is expected if the node already exists. Unlike the
// main upsertNode function, this function does not require that a new node
// update have a newer timestamp than the existing one. This is because we want
// the migration to be idempotent and don't want to error out if we re-insert
// the exact same node.
func insertNodeSQLMig(ctx context.Context, db SQLQueries,
        node *models.LightningNode) (int64, error) {

        params := sqlc.InsertNodeMigParams{
                Version: int16(ProtocolV1),
                PubKey:  node.PubKeyBytes[:],
        }

        if node.HaveNodeAnnouncement {
                params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
                params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
                params.Alias = sqldb.SQLStr(node.Alias)
                params.Signature = node.AuthSigBytes
        }

        nodeID, err := db.InsertNodeMig(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveNodeAnnouncement {
                return nodeID, nil
        }

        // Insert the node's features.
        for feature := range node.Features.Features() {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: int32(feature),
                })
                if err != nil {
                        return 0, fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        // Update the node's addresses.
        newAddresses, err := collectAddressRecords(node.Addresses)
        if err != nil {
                return 0, err
        }

        // Any remaining entries in newAddresses are new addresses that need to
        // be added to the database for the first time.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.InsertNodeAddress(
                                ctx, sqlc.InsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return 0, fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
        if err != nil {
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Insert the node's extra signed fields.
        for tlvType, value := range extra {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return 0, fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nodeID, nil
}
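
// reinsertNodeSketch is an illustrative sketch (not called from any code
// path) of the idempotency property described above: inserting the same node
// twice is expected to succeed and, assuming InsertNodeMig resolves to the
// existing row on conflict, to return the same database ID both times.
func reinsertNodeSketch(ctx context.Context, db SQLQueries,
        node *models.LightningNode) error {

        id1, err := insertNodeSQLMig(ctx, db, node)
        if err != nil {
                return fmt.Errorf("first insert failed: %w", err)
        }

        id2, err := insertNodeSQLMig(ctx, db, node)
        if err != nil {
                return fmt.Errorf("re-insert failed: %w", err)
        }

        if id1 != id2 {
                return fmt.Errorf("expected a stable node ID, got %d and %d",
                        id1, id2)
        }

        return nil
}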

// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
// channel.
type dbChanInfo struct {
        channelID int64
        node1ID   int64
        node2ID   int64
}

// insertChannelMig inserts a new channel record into the database during the
// graph SQL migration.
func insertChannelMig(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        //
        // NOTE: we need this even during the SQL migration where nodes are
        // migrated first because there are cases where some nodes may have
        // been skipped due to invalid TLV data.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return nil, fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.InsertChannelMigParams{
                Version:     int16(ProtocolV1),
                Scid:        channelIDToBytes(edge.ChannelID),
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.InsertChannelMig(ctx, createParams)
        if err != nil {
                return nil, err
        }

        // Insert any channel features.
        for feature := range edge.Features.Features() {
                err = db.InsertChannelFeature(
                        ctx, sqlc.InsertChannelFeatureParams{
                                ChannelID:  dbChanID,
                                FeatureBit: int32(feature),
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to insert channel(%d) "+
                                "feature(%v): %w", dbChanID, feature, err)
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return nil, fmt.Errorf("unable to marshal extra opaque "+
                        "data: %w", err)
        }

        for tlvType, value := range extra {
                err := db.CreateChannelExtraType(
                        ctx, sqlc.CreateChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to upsert "+
                                "channel(%d) extra signed field(%v): %w",
                                edge.ChannelID, tlvType, err)
                }
        }

        return &dbChanInfo{
                channelID: dbChanID,
                node1ID:   node1DBID,
                node2ID:   node2DBID,
        }, nil
}
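
// migrateSingleChannelSketch is an illustrative sketch (not called from any
// code path) of how insertChannelMig and insertChanEdgePolicyMig (defined
// below) fit together: the channel row is inserted first, and whichever of
// the two directed policies exist are then attached to it. In the real
// migration the edge and policies come from iterating the KV store in
// migrateChannelsAndPolicies.
func migrateSingleChannelSketch(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo,
        policy1, policy2 *models.ChannelEdgePolicy) error {

        dbChan, err := insertChannelMig(ctx, db, edge)
        if err != nil {
                return fmt.Errorf("could not insert channel: %w", err)
        }

        for _, policy := range []*models.ChannelEdgePolicy{policy1, policy2} {
                if policy == nil {
                        continue
                }

                err := insertChanEdgePolicyMig(ctx, db, dbChan, policy)
                if err != nil {
                        return fmt.Errorf("could not insert policy for "+
                                "channel %d: %w", edge.ChannelID, err)
                }
        }

        return nil
}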

// insertChanEdgePolicyMig inserts the channel policy info we have stored for
// a channel we already know of. It is used during the graph SQL migration.
func insertChanEdgePolicyMig(ctx context.Context, tx SQLQueries,
        dbChan *dbChanInfo, edge *models.ChannelEdgePolicy) error {

        // Figure out which node this edge is from.
        isNode1 := edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.node1ID
        if !isNode1 {
                nodeID = dbChan.node2ID
        }

        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.InsertEdgePolicyMig(ctx, sqlc.InsertEdgePolicyMigParams{
                Version:     int16(ProtocolV1),
                ChannelID:   dbChan.channelID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        // Insert all new extra signed fields for the channel policy.
        for tlvType, value := range extra {
                err = tx.InsertChanPolicyExtraType(
                        ctx, sqlc.InsertChanPolicyExtraTypeParams{
                                ChannelPolicyID: id,
                                Type:            int64(tlvType),
                                Value:           value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert "+
                                "channel_policy(%d) extra signed field(%v): %w",
                                id, tlvType, err)
                }
        }

        return nil
}
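
// policyFromNode1Sketch is an illustrative helper (not called from any code
// path) mirroring the direction check in insertChanEdgePolicyMig above: the
// ChanUpdateDirection bit of ChannelFlags is unset for a policy originating
// from node 1 and set for one originating from node 2.
func policyFromNode1Sketch(policy *models.ChannelEdgePolicy) bool {
        return policy.ChannelFlags&lnwire.ChanUpdateDirection == 0
}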