• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 21994242646

13 Feb 2026 04:22PM UTC coverage: 65.12% (+0.2%) from 64.883%
21994242646

push

github

web-flow
Merge pull request #10414 from lightningnetwork/elle-g175Prep-base

[g175] graph/db: merge g175 types-prep side branch

673 of 1704 new or added lines in 23 files covered. (39.5%)

104 existing lines in 23 files now uncovered.

139183 of 213734 relevant lines covered (65.12%)

20622.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
37
// execute queries against the SQL graph tables.
38
//
39
//nolint:ll,interfacebloat
40
type SQLQueries interface {
41
        /*
42
                Node queries.
43
        */
44
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
45
        UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
46
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
47
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
48
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
49
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
50
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
51
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
52
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
53
        IsPublicV2Node(ctx context.Context, pubKey []byte) (bool, error)
54
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
55
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
56
        DeleteNode(ctx context.Context, id int64) error
57
        NodeExists(ctx context.Context, arg sqlc.NodeExistsParams) (bool, error)
58

59
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
60
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
61
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
62
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
63

64
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
65
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
66
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
67
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
68

69
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
70
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
71
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
72
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
73
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
74

75
        /*
76
                Source node queries.
77
        */
78
        AddSourceNode(ctx context.Context, nodeID int64) error
79
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
80

81
        /*
82
                Channel queries.
83
        */
84
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
85
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
86
        AddV2ChannelProof(ctx context.Context, arg sqlc.AddV2ChannelProofParams) (sql.Result, error)
87
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
88
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
89
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
90
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
91
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
92
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
93
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
94
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
95
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
96
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
97
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
98
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
99
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
100
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
101
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
102
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
103
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
104
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
105
        DeleteChannels(ctx context.Context, ids []int64) error
106

107
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
108
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
109
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
110
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
111

112
        /*
113
                Channel Policy table queries.
114
        */
115
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
116
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
117
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
118

119
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
120
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
121
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
122

123
        /*
124
                Zombie index queries.
125
        */
126
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
127
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
128
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
129
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
130
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
131
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
132

133
        /*
134
                Prune log table queries.
135
        */
136
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
137
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
138
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
139
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
140
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
141

142
        /*
143
                Closed SCID table queries.
144
        */
145
        InsertClosedChannel(ctx context.Context, scid []byte) error
146
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
147
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
148

149
        /*
150
                Migration specific queries.
151

152
                NOTE: these should not be used in code other than migrations.
153
                Once sqldbv2 is in place, these can be removed from this struct
154
                as then migrations will have their own dedicated queries
155
                structs.
156
        */
157
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
158
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
159
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
160
}
161

162
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
163
// database operations.
164
type BatchedSQLQueries interface {
165
        SQLQueries
166
        sqldb.BatchedTx[SQLQueries]
167
}
168

169
// SQLStore is an implementation of the Store interface that uses a SQL
170
// database as the backend.
171
type SQLStore struct {
172
        cfg *SQLStoreConfig
173
        db  BatchedSQLQueries
174

175
        // cacheMu guards all caches (rejectCache and chanCache). If
176
        // this mutex will be acquired at the same time as the DB mutex then
177
        // the cacheMu MUST be acquired first to prevent deadlock.
178
        cacheMu     sync.RWMutex
179
        rejectCache *rejectCache
180
        chanCache   *channelCache
181

182
        chanScheduler batch.Scheduler[SQLQueries]
183
        nodeScheduler batch.Scheduler[SQLQueries]
184

185
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
186
        srcNodeMu sync.Mutex
187
}
188

189
// A compile-time assertion to ensure that SQLStore implements the Store
190
// interface.
191
var _ Store = (*SQLStore)(nil)
192

193
// SQLStoreConfig holds the configuration for the SQLStore.
194
type SQLStoreConfig struct {
195
        // ChainHash is the genesis hash for the chain that all the gossip
196
        // messages in this store are aimed at.
197
        ChainHash chainhash.Hash
198

199
        // QueryConfig holds configuration values for SQL queries.
200
        QueryCfg *sqldb.QueryConfig
201
}
202

203
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
204
// storage backend.
205
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
206
        options ...StoreOptionModifier) (*SQLStore, error) {
×
207

×
208
        opts := DefaultOptions()
×
209
        for _, o := range options {
×
210
                o(opts)
×
211
        }
×
212

213
        if opts.NoMigration {
×
214
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
215
                        "supported for SQL stores")
×
216
        }
×
217

218
        s := &SQLStore{
×
219
                cfg:         cfg,
×
220
                db:          db,
×
221
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
222
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
223
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
224
        }
×
225

×
226
        s.chanScheduler = batch.NewTimeScheduler(
×
227
                db, &s.cacheMu, opts.BatchCommitInterval,
×
228
        )
×
229
        s.nodeScheduler = batch.NewTimeScheduler(
×
230
                db, nil, opts.BatchCommitInterval,
×
231
        )
×
232

×
233
        return s, nil
×
234
}
235

236
// AddNode adds a vertex/node to the graph database. If the node is not
237
// in the database from before, this will add a new, unconnected one to the
238
// graph. If it is present from before, this will update that node's
239
// information.
240
//
241
// NOTE: part of the Store interface.
242
func (s *SQLStore) AddNode(ctx context.Context,
243
        node *models.Node, opts ...batch.SchedulerOption) error {
×
244

×
245
        r := &batch.Request[SQLQueries]{
×
246
                Opts: batch.NewSchedulerOptions(opts...),
×
247
                Do: func(queries SQLQueries) error {
×
248
                        _, err := upsertNode(ctx, queries, node)
×
249

×
250
                        // It is possible that two of the same node
×
251
                        // announcements are both being processed in the same
×
252
                        // batch. This may case the UpsertNode conflict to
×
253
                        // be hit since we require at the db layer that the
×
254
                        // new last_update is greater than the existing
×
255
                        // last_update. We need to gracefully handle this here.
×
256
                        if errors.Is(err, sql.ErrNoRows) {
×
257
                                return nil
×
258
                        }
×
259

260
                        return err
×
261
                },
262
        }
263

264
        return s.nodeScheduler.Execute(ctx, r)
×
265
}
266

267
// FetchNode attempts to look up a target node by its identity public
268
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
269
// returned.
270
//
271
// NOTE: part of the Store interface.
272
func (s *SQLStore) FetchNode(ctx context.Context, v lnwire.GossipVersion,
273
        pubKey route.Vertex) (*models.Node, error) {
×
274

×
275
        var node *models.Node
×
276
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
277
                var err error
×
NEW
278
                _, node, err = getNodeByPubKey(
×
NEW
279
                        ctx, s.cfg.QueryCfg, db, v, pubKey,
×
NEW
280
                )
×
281

×
282
                return err
×
283
        }, sqldb.NoOpReset)
×
284
        if err != nil {
×
285
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
286
        }
×
287

288
        return node, nil
×
289
}
290

291
// HasV1Node determines if the graph has a vertex identified by the
292
// target node identity public key. If the node exists in the database, a
293
// timestamp of when the data for the node was lasted updated is returned along
294
// with a true boolean. Otherwise, an empty time.Time is returned with a false
295
// boolean.
296
//
297
// NOTE: part of the Store interface.
298
func (s *SQLStore) HasV1Node(ctx context.Context,
299
        pubKey [33]byte) (time.Time, bool, error) {
×
300

×
301
        var (
×
302
                exists     bool
×
303
                lastUpdate time.Time
×
304
        )
×
305
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
306
                dbNode, err := db.GetNodeByPubKey(
×
307
                        ctx, sqlc.GetNodeByPubKeyParams{
×
308
                                Version: int16(lnwire.GossipVersion1),
×
309
                                PubKey:  pubKey[:],
×
310
                        },
×
311
                )
×
312
                if errors.Is(err, sql.ErrNoRows) {
×
313
                        return nil
×
314
                } else if err != nil {
×
315
                        return fmt.Errorf("unable to fetch node: %w", err)
×
316
                }
×
317

318
                exists = true
×
319

×
320
                if dbNode.LastUpdate.Valid {
×
321
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
322
                }
×
323

324
                return nil
×
325
        }, sqldb.NoOpReset)
326
        if err != nil {
×
327
                return time.Time{}, false,
×
328
                        fmt.Errorf("unable to fetch node: %w", err)
×
329
        }
×
330

331
        return lastUpdate, exists, nil
×
332
}
333

334
// HasNode determines if the graph has a vertex identified by the
335
// target node identity public key.
336
//
337
// NOTE: part of the Store interface.
338
func (s *SQLStore) HasNode(ctx context.Context, v lnwire.GossipVersion,
NEW
339
        pubKey [33]byte) (bool, error) {
×
NEW
340

×
NEW
341
        var exists bool
×
NEW
342
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
343
                var err error
×
NEW
344
                exists, err = db.NodeExists(ctx, sqlc.NodeExistsParams{
×
NEW
345
                        Version: int16(v),
×
NEW
346
                        PubKey:  pubKey[:],
×
NEW
347
                })
×
NEW
348

×
NEW
349
                return err
×
NEW
350
        }, sqldb.NoOpReset)
×
NEW
351
        if err != nil {
×
NEW
352
                return false, fmt.Errorf("unable to check if node (%x) "+
×
NEW
353
                        "exists: %w", pubKey, err)
×
NEW
354
        }
×
355

NEW
356
        return exists, nil
×
357
}
358

359
// AddrsForNode returns all known addresses for the target node public key
360
// that the graph DB is aware of. The returned boolean indicates if the
361
// given node is unknown to the graph DB or not.
362
//
363
// NOTE: part of the Store interface.
364
func (s *SQLStore) AddrsForNode(ctx context.Context, v lnwire.GossipVersion,
365
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
366

×
367
        var (
×
368
                addresses []net.Addr
×
369
                known     bool
×
370
        )
×
371
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
372
                // First, check if the node exists and get its DB ID if it
×
373
                // does.
×
374
                dbID, err := db.GetNodeIDByPubKey(
×
375
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
NEW
376
                                Version: int16(v),
×
377
                                PubKey:  nodePub.SerializeCompressed(),
×
378
                        },
×
379
                )
×
380
                if errors.Is(err, sql.ErrNoRows) {
×
381
                        return nil
×
382
                }
×
383

384
                known = true
×
385

×
386
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
387
                if err != nil {
×
388
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
389
                                err)
×
390
                }
×
391

392
                return nil
×
393
        }, sqldb.NoOpReset)
394
        if err != nil {
×
395
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
396
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
397
        }
×
398

399
        return known, addresses, nil
×
400
}
401

402
// DeleteNode starts a new database transaction to remove a vertex/node
403
// from the database according to the node's public key.
404
//
405
// NOTE: part of the Store interface.
406
func (s *SQLStore) DeleteNode(ctx context.Context, v lnwire.GossipVersion,
407
        pubKey route.Vertex) error {
×
408

×
409
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
410
                res, err := db.DeleteNodeByPubKey(
×
411
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
NEW
412
                                Version: int16(v),
×
413
                                PubKey:  pubKey[:],
×
414
                        },
×
415
                )
×
416
                if err != nil {
×
417
                        return err
×
418
                }
×
419

420
                rows, err := res.RowsAffected()
×
421
                if err != nil {
×
422
                        return err
×
423
                }
×
424

425
                if rows == 0 {
×
426
                        return ErrGraphNodeNotFound
×
427
                } else if rows > 1 {
×
428
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
429
                }
×
430

431
                return err
×
432
        }, sqldb.NoOpReset)
433
        if err != nil {
×
434
                return fmt.Errorf("unable to delete node: %w", err)
×
435
        }
×
436

437
        return nil
×
438
}
439

440
// FetchNodeFeatures returns the features of the given node. If no features are
441
// known for the node, an empty feature vector is returned.
442
//
443
// NOTE: this is part of the graphdb.NodeTraverser interface.
444
func (s *SQLStore) FetchNodeFeatures(v lnwire.GossipVersion,
NEW
445
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
446

×
447
        ctx := context.TODO()
×
448

×
NEW
449
        return fetchNodeFeatures(ctx, s.db, v, nodePub)
×
450
}
×
451

452
// DisabledChannelIDs returns the channel ids of disabled channels.
453
// A channel is disabled when two of the associated ChanelEdgePolicies
454
// have their disabled bit on.
455
//
456
// NOTE: part of the Store interface.
457
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
458
        var (
×
459
                ctx     = context.TODO()
×
460
                chanIDs []uint64
×
461
        )
×
462
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
463
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
464
                if err != nil {
×
465
                        return fmt.Errorf("unable to fetch disabled "+
×
466
                                "channels: %w", err)
×
467
                }
×
468

469
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
470

×
471
                return nil
×
472
        }, sqldb.NoOpReset)
473
        if err != nil {
×
474
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
475
                        err)
×
476
        }
×
477

478
        return chanIDs, nil
×
479
}
480

481
// LookupAlias attempts to return the alias as advertised by the target node.
482
//
483
// NOTE: part of the Store interface.
484
func (s *SQLStore) LookupAlias(ctx context.Context, v lnwire.GossipVersion,
485
        pub *btcec.PublicKey) (string, error) {
×
486

×
487
        var alias string
×
488
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
489
                dbNode, err := db.GetNodeByPubKey(
×
490
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
491
                                Version: int16(v),
×
492
                                PubKey:  pub.SerializeCompressed(),
×
493
                        },
×
494
                )
×
495
                if errors.Is(err, sql.ErrNoRows) {
×
496
                        return ErrNodeAliasNotFound
×
497
                } else if err != nil {
×
498
                        return fmt.Errorf("unable to fetch node: %w", err)
×
499
                }
×
500

501
                if !dbNode.Alias.Valid {
×
502
                        return ErrNodeAliasNotFound
×
503
                }
×
504

505
                alias = dbNode.Alias.String
×
506

×
507
                return nil
×
508
        }, sqldb.NoOpReset)
509
        if err != nil {
×
510
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
511
        }
×
512

513
        return alias, nil
×
514
}
515

516
// SourceNode returns the source node of the graph. The source node is treated
517
// as the center node within a star-graph. This method may be used to kick off
518
// a path finding algorithm in order to explore the reachability of another
519
// node based off the source node.
520
//
521
// NOTE: part of the Store interface.
522
func (s *SQLStore) SourceNode(ctx context.Context,
NEW
523
        v lnwire.GossipVersion) (*models.Node, error) {
×
524

×
525
        var node *models.Node
×
526
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
527
                _, nodePub, err := s.getSourceNode(ctx, db, v)
×
528
                if err != nil {
×
NEW
529
                        return fmt.Errorf("unable to fetch source node: %w",
×
530
                                err)
×
531
                }
×
532

NEW
533
                _, node, err = getNodeByPubKey(
×
NEW
534
                        ctx, s.cfg.QueryCfg, db, v, nodePub,
×
NEW
535
                )
×
536

×
537
                return err
×
538
        }, sqldb.NoOpReset)
539
        if err != nil {
×
540
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
541
        }
×
542

543
        return node, nil
×
544
}
545

546
// SetSourceNode sets the source node within the graph database. The source
547
// node is to be used as the center of a star-graph within path finding
548
// algorithms.
549
//
550
// NOTE: part of the Store interface.
551
func (s *SQLStore) SetSourceNode(ctx context.Context,
552
        node *models.Node) error {
×
553

×
554
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
555
                // For the source node, we use a less strict upsert that allows
×
556
                // updates even when the timestamp hasn't changed. This handles
×
557
                // the race condition where multiple goroutines (e.g.,
×
558
                // setSelfNode, createNewHiddenService, RPC updates) read the
×
559
                // same old timestamp, independently increment it, and try to
×
560
                // write concurrently. We want all parameter changes to persist,
×
561
                // even if timestamps collide.
×
562
                id, err := upsertSourceNode(ctx, db, node)
×
563
                if err != nil {
×
564
                        return fmt.Errorf("unable to upsert source node: %w",
×
565
                                err)
×
566
                }
×
567

568
                // Make sure that if a source node for this version is already
569
                // set, then the ID is the same as the one we are about to set.
570
                dbSourceNodeID, _, err := s.getSourceNode(
×
NEW
571
                        ctx, db, node.Version,
×
572
                )
×
573
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
574
                        return fmt.Errorf("unable to fetch source node: %w",
×
575
                                err)
×
576
                } else if err == nil {
×
577
                        if dbSourceNodeID != id {
×
578
                                return fmt.Errorf("v1 source node already "+
×
579
                                        "set to a different node: %d vs %d",
×
580
                                        dbSourceNodeID, id)
×
581
                        }
×
582

583
                        return nil
×
584
                }
585

586
                return db.AddSourceNode(ctx, id)
×
587
        }, sqldb.NoOpReset)
588
}
589

590
// NodeUpdatesInHorizon returns all the known lightning node which have an
591
// update timestamp within the passed range. This method can be used by two
592
// nodes to quickly determine if they have the same set of up to date node
593
// announcements.
594
//
595
// NOTE: This is part of the Store interface.
596
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
597
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
598

×
599
        cfg := defaultIteratorConfig()
×
600
        for _, opt := range opts {
×
601
                opt(cfg)
×
602
        }
×
603

604
        return func(yield func(*models.Node, error) bool) {
×
605
                var (
×
606
                        ctx            = context.TODO()
×
607
                        lastUpdateTime sql.NullInt64
×
608
                        lastPubKey     = make([]byte, 33)
×
609
                        hasMore        = true
×
610
                )
×
611

×
612
                // Each iteration, we'll read a batch amount of nodes, yield
×
613
                // them, then decide is we have more or not.
×
614
                for hasMore {
×
615
                        var batch []*models.Node
×
616

×
617
                        //nolint:ll
×
618
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
619
                                //nolint:ll
×
620
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
621
                                        StartTime: sqldb.SQLInt64(
×
622
                                                startTime.Unix(),
×
623
                                        ),
×
624
                                        EndTime: sqldb.SQLInt64(
×
625
                                                endTime.Unix(),
×
626
                                        ),
×
627
                                        LastUpdate: lastUpdateTime,
×
628
                                        LastPubKey: lastPubKey,
×
629
                                        OnlyPublic: sql.NullBool{
×
630
                                                Bool:  cfg.iterPublicNodes,
×
631
                                                Valid: true,
×
632
                                        },
×
633
                                        MaxResults: sqldb.SQLInt32(
×
634
                                                cfg.nodeUpdateIterBatchSize,
×
635
                                        ),
×
636
                                }
×
637
                                rows, err := db.GetNodesByLastUpdateRange(
×
638
                                        ctx, params,
×
639
                                )
×
640
                                if err != nil {
×
641
                                        return err
×
642
                                }
×
643

644
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
645

×
646
                                err = forEachNodeInBatch(
×
647
                                        ctx, s.cfg.QueryCfg, db, rows,
×
648
                                        func(_ int64, node *models.Node) error {
×
649
                                                batch = append(batch, node)
×
650

×
651
                                                // Update pagination cursors
×
652
                                                // based on the last processed
×
653
                                                // node.
×
654
                                                lastUpdateTime = sql.NullInt64{
×
655
                                                        Int64: node.LastUpdate.
×
656
                                                                Unix(),
×
657
                                                        Valid: true,
×
658
                                                }
×
659
                                                lastPubKey = node.PubKeyBytes[:]
×
660

×
661
                                                return nil
×
662
                                        },
×
663
                                )
664
                                if err != nil {
×
665
                                        return fmt.Errorf("unable to build "+
×
666
                                                "nodes: %w", err)
×
667
                                }
×
668

669
                                return nil
×
670
                        }, func() {
×
671
                                batch = []*models.Node{}
×
672
                        })
×
673

674
                        if err != nil {
×
675
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
676
                                        "error: %v", err)
×
677

×
678
                                yield(&models.Node{}, err)
×
679

×
680
                                return
×
681
                        }
×
682

683
                        for _, node := range batch {
×
684
                                if !yield(node, nil) {
×
685
                                        return
×
686
                                }
×
687
                        }
688

689
                        // If the batch didn't yield anything, then we're done.
690
                        if len(batch) == 0 {
×
691
                                break
×
692
                        }
693
                }
694
        }
695
}
696

697
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
698
// undirected edge from the two target nodes are created. The information stored
699
// denotes the static attributes of the channel, such as the channelID, the keys
700
// involved in creation of the channel, and the set of features that the channel
701
// supports. The chanPoint and chanID are used to uniquely identify the edge
702
// globally within the database.
703
//
704
// NOTE: part of the Store interface.
705
func (s *SQLStore) AddChannelEdge(ctx context.Context,
706
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
707

×
NEW
708
        if !isKnownGossipVersion(edge.Version) {
×
NEW
709
                return fmt.Errorf("unsupported gossip version: %d",
×
NEW
710
                        edge.Version)
×
NEW
711
        }
×
712

713
        var alreadyExists bool
×
714
        r := &batch.Request[SQLQueries]{
×
715
                Opts: batch.NewSchedulerOptions(opts...),
×
716
                Reset: func() {
×
717
                        alreadyExists = false
×
718
                },
×
719
                Do: func(tx SQLQueries) error {
×
720
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
721

×
722
                        // Make sure that the channel doesn't already exist. We
×
723
                        // do this explicitly instead of relying on catching a
×
724
                        // unique constraint error because relying on SQL to
×
725
                        // throw that error would abort the entire batch of
×
726
                        // transactions.
×
727
                        _, err := tx.GetChannelBySCID(
×
728
                                ctx, sqlc.GetChannelBySCIDParams{
×
729
                                        Scid:    chanIDB,
×
NEW
730
                                        Version: int16(edge.Version),
×
731
                                },
×
732
                        )
×
733
                        if err == nil {
×
734
                                alreadyExists = true
×
735
                                return nil
×
736
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
737
                                return fmt.Errorf("unable to fetch channel: %w",
×
738
                                        err)
×
739
                        }
×
740

741
                        return insertChannel(ctx, tx, edge)
×
742
                },
743
                OnCommit: func(err error) error {
×
744
                        switch {
×
745
                        case err != nil:
×
746
                                return err
×
747
                        case alreadyExists:
×
748
                                return ErrEdgeAlreadyExist
×
749
                        default:
×
NEW
750
                                s.rejectCache.remove(
×
NEW
751
                                        edge.Version, edge.ChannelID,
×
NEW
752
                                )
×
753
                                s.chanCache.remove(edge.ChannelID)
×
NEW
754

×
UNCOV
755
                                return nil
×
756
                        }
757
                },
758
        }
759

760
        return s.chanScheduler.Execute(ctx, r)
×
761
}
762

763
// HighestChanID returns the "highest" known channel ID in the channel graph.
764
// This represents the "newest" channel from the PoV of the chain. This method
765
// can be used by peers to quickly determine if their graphs are in sync.
766
//
767
// NOTE: This is part of the Store interface.
768
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
769
        var highestChanID uint64
×
770
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
771
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
772
                if errors.Is(err, sql.ErrNoRows) {
×
773
                        return nil
×
774
                } else if err != nil {
×
775
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
776
                                err)
×
777
                }
×
778

779
                highestChanID = byteOrder.Uint64(chanID)
×
780

×
781
                return nil
×
782
        }, sqldb.NoOpReset)
783
        if err != nil {
×
784
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
785
        }
×
786

787
        return highestChanID, nil
×
788
}
789

790
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
791
// within the database for the referenced channel. The `flags` attribute within
792
// the ChannelEdgePolicy determines which of the directed edges are being
793
// updated. If the flag is 1, then the first node's information is being
794
// updated, otherwise it's the second node's information. The node ordering is
795
// determined by the lexicographical ordering of the identity public keys of the
796
// nodes on either side of the channel.
797
//
798
// NOTE: part of the Store interface.
799
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
800
        edge *models.ChannelEdgePolicy,
801
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
802

×
803
        var (
×
804
                isUpdate1    bool
×
805
                edgeNotFound bool
×
806
                from, to     route.Vertex
×
807
        )
×
808

×
809
        r := &batch.Request[SQLQueries]{
×
810
                Opts: batch.NewSchedulerOptions(opts...),
×
811
                Reset: func() {
×
812
                        isUpdate1 = false
×
813
                        edgeNotFound = false
×
814
                },
×
815
                Do: func(tx SQLQueries) error {
×
816
                        var err error
×
817
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
818
                                ctx, tx, edge,
×
819
                        )
×
820
                        // It is possible that two of the same policy
×
821
                        // announcements are both being processed in the same
×
822
                        // batch. This may case the UpsertEdgePolicy conflict to
×
823
                        // be hit since we require at the db layer that the
×
824
                        // new last_update is greater than the existing
×
825
                        // last_update. We need to gracefully handle this here.
×
826
                        if errors.Is(err, sql.ErrNoRows) {
×
827
                                return nil
×
828
                        } else if err != nil {
×
829
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
830
                        }
×
831

832
                        // Silence ErrEdgeNotFound so that the batch can
833
                        // succeed, but propagate the error via local state.
834
                        if errors.Is(err, ErrEdgeNotFound) {
×
835
                                edgeNotFound = true
×
836
                                return nil
×
837
                        }
×
838

839
                        return err
×
840
                },
841
                OnCommit: func(err error) error {
×
842
                        switch {
×
843
                        case err != nil:
×
844
                                return err
×
845
                        case edgeNotFound:
×
846
                                return ErrEdgeNotFound
×
847
                        default:
×
848
                                s.updateEdgeCache(edge, isUpdate1)
×
849
                                return nil
×
850
                        }
851
                },
852
        }
853

854
        err := s.chanScheduler.Execute(ctx, r)
×
855

×
856
        return from, to, err
×
857
}
858

859
// updateEdgeCache updates our reject and channel caches with the new
860
// edge policy information.
861
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
862
        isUpdate1 bool) {
×
863

×
864
        // If an entry for this channel is found in reject cache, we'll modify
×
865
        // the entry with the updated timestamp for the direction that was just
×
866
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
867
        // during the next query for this edge.
×
NEW
868
        if entry, ok := s.rejectCache.get(e.Version, e.ChannelID); ok {
×
NEW
869
                switch e.Version {
×
NEW
870
                case lnwire.GossipVersion1:
×
NEW
871
                        updateRejectCacheEntryV1(
×
NEW
872
                                &entry, isUpdate1, e.LastUpdate,
×
NEW
873
                        )
×
NEW
874
                case lnwire.GossipVersion2:
×
NEW
875
                        updateRejectCacheEntryV2(
×
NEW
876
                                &entry, isUpdate1, e.LastBlockHeight,
×
NEW
877
                        )
×
878
                }
NEW
879
                s.rejectCache.insert(e.Version, e.ChannelID, entry)
×
880
        }
881

882
        // If an entry for this channel is found in channel cache, we'll modify
883
        // the entry with the updated policy for the direction that was just
884
        // written. If the edge doesn't exist, we'll defer loading the info and
885
        // policies and lazily read from disk during the next query.
886
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
887
                if isUpdate1 {
×
888
                        channel.Policy1 = e
×
889
                } else {
×
890
                        channel.Policy2 = e
×
891
                }
×
892
                s.chanCache.insert(e.ChannelID, channel)
×
893
        }
894
}
895

896
// ForEachSourceNodeChannel iterates through all channels of the source node,
897
// executing the passed callback on each. The call-back is provided with the
898
// channel's outpoint, whether we have a policy for the channel and the channel
899
// peer's node information.
900
//
901
// NOTE: part of the Store interface.
902
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
903
        cb func(chanPoint wire.OutPoint, havePolicy bool,
904
                otherNode *models.Node) error, reset func()) error {
×
905

×
906
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
907
                nodeID, nodePub, err := s.getSourceNode(
×
908
                        ctx, db, lnwire.GossipVersion1,
×
909
                )
×
910
                if err != nil {
×
911
                        return fmt.Errorf("unable to fetch source node: %w",
×
912
                                err)
×
913
                }
×
914

915
                return forEachNodeChannel(
×
NEW
916
                        ctx, db, s.cfg, lnwire.GossipVersion1, nodeID,
×
917
                        func(info *models.ChannelEdgeInfo,
×
918
                                outPolicy *models.ChannelEdgePolicy,
×
919
                                _ *models.ChannelEdgePolicy) error {
×
920

×
921
                                // Fetch the other node.
×
922
                                var (
×
923
                                        otherNodePub [33]byte
×
924
                                        node1        = info.NodeKey1Bytes
×
925
                                        node2        = info.NodeKey2Bytes
×
926
                                )
×
927
                                switch {
×
928
                                case bytes.Equal(node1[:], nodePub[:]):
×
929
                                        otherNodePub = node2
×
930
                                case bytes.Equal(node2[:], nodePub[:]):
×
931
                                        otherNodePub = node1
×
932
                                default:
×
933
                                        return fmt.Errorf("node not " +
×
934
                                                "participating in this channel")
×
935
                                }
936

937
                                _, otherNode, err := getNodeByPubKey(
×
NEW
938
                                        ctx, s.cfg.QueryCfg, db,
×
NEW
939
                                        lnwire.GossipVersion1, otherNodePub,
×
940
                                )
×
941
                                if err != nil {
×
942
                                        return fmt.Errorf("unable to fetch "+
×
943
                                                "other node(%x): %w",
×
944
                                                otherNodePub, err)
×
945
                                }
×
946

947
                                return cb(
×
948
                                        info.ChannelPoint, outPolicy != nil,
×
949
                                        otherNode,
×
950
                                )
×
951
                        },
952
                )
953
        }, reset)
954
}
955

956
// ForEachNode iterates through all the stored vertices/nodes in the graph,
957
// executing the passed callback with each node encountered. If the callback
958
// returns an error, then the transaction is aborted and the iteration stops
959
// early.
960
//
961
// NOTE: part of the Store interface.
962
func (s *SQLStore) ForEachNode(ctx context.Context,
963
        cb func(node *models.Node) error, reset func()) error {
×
964

×
965
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
966
                return forEachNodePaginated(
×
967
                        ctx, s.cfg.QueryCfg, db,
×
968
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
969
                                node *models.Node) error {
×
970

×
971
                                return cb(node)
×
972
                        },
×
973
                )
974
        }, reset)
975
}
976

977
// ForEachNodeDirectedChannel iterates through all channels of a given node,
978
// executing the passed callback on the directed edge representing the channel
979
// and its incoming policy. If the callback returns an error, then the iteration
980
// is halted with the error propagated back up to the caller.
981
//
982
// Unknown policies are passed into the callback as nil values.
983
//
984
// NOTE: this is part of the graphdb.NodeTraverser interface.
985
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
986
        cb func(channel *DirectedChannel) error, reset func()) error {
×
987

×
988
        var ctx = context.TODO()
×
989

×
990
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
991
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
992
        }, reset)
×
993
}
994

995
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
996
// graph, executing the passed callback with each node encountered. If the
997
// callback returns an error, then the transaction is aborted and the iteration
998
// stops early.
999
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
1000
        cb func(route.Vertex, *lnwire.FeatureVector) error,
1001
        reset func()) error {
×
1002

×
1003
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1004
                return forEachNodeCacheable(
×
1005
                        ctx, s.cfg.QueryCfg, db,
×
1006
                        func(_ int64, nodePub route.Vertex,
×
1007
                                features *lnwire.FeatureVector) error {
×
1008

×
1009
                                return cb(nodePub, features)
×
1010
                        },
×
1011
                )
1012
        }, reset)
1013
        if err != nil {
×
1014
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
1015
        }
×
1016

1017
        return nil
×
1018
}
1019

1020
// ForEachNodeChannel iterates through all channels of the given node,
1021
// executing the passed callback with an edge info structure and the policies
1022
// of each end of the channel. The first edge policy is the outgoing edge *to*
1023
// the connecting node, while the second is the incoming edge *from* the
1024
// connecting node. If the callback returns an error, then the iteration is
1025
// halted with the error propagated back up to the caller.
1026
//
1027
// Unknown policies are passed into the callback as nil values.
1028
//
1029
// NOTE: part of the Store interface.
1030
func (s *SQLStore) ForEachNodeChannel(ctx context.Context,
1031
        v lnwire.GossipVersion, nodePub route.Vertex,
1032
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1033
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1034

×
1035
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1036
                dbNode, err := db.GetNodeByPubKey(
×
1037
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
1038
                                Version: int16(v),
×
1039
                                PubKey:  nodePub[:],
×
1040
                        },
×
1041
                )
×
1042
                if errors.Is(err, sql.ErrNoRows) {
×
1043
                        return nil
×
1044
                } else if err != nil {
×
1045
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1046
                }
×
1047

NEW
1048
                return forEachNodeChannel(ctx, db, s.cfg, v, dbNode.ID, cb)
×
1049
        }, reset)
1050
}
1051

1052
// extractMaxUpdateTime returns the maximum of the two policy update times.
1053
// This is used for pagination cursor tracking.
1054
func extractMaxUpdateTime(
1055
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1056

×
1057
        switch {
×
1058
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1059
                return max(row.Policy1LastUpdate.Int64,
×
1060
                        row.Policy2LastUpdate.Int64)
×
1061
        case row.Policy1LastUpdate.Valid:
×
1062
                return row.Policy1LastUpdate.Int64
×
1063
        case row.Policy2LastUpdate.Valid:
×
1064
                return row.Policy2LastUpdate.Int64
×
1065
        default:
×
1066
                return 0
×
1067
        }
1068
}
1069

1070
// buildChannelFromRow constructs a ChannelEdge from a database row.
1071
// This includes building the nodes, channel info, and policies.
1072
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1073
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1074

×
1075
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1076
        if err != nil {
×
1077
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1078
                        err)
×
1079
        }
×
1080

1081
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1082
        if err != nil {
×
1083
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1084
                        err)
×
1085
        }
×
1086

1087
        channel, err := getAndBuildEdgeInfo(
×
1088
                ctx, s.cfg, db,
×
1089
                row.GraphChannel, node1.PubKeyBytes,
×
1090
                node2.PubKeyBytes,
×
1091
        )
×
1092
        if err != nil {
×
1093
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1094
                        "channel info: %w", err)
×
1095
        }
×
1096

1097
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1098
        if err != nil {
×
1099
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1100
                        "channel policies: %w", err)
×
1101
        }
×
1102

1103
        p1, p2, err := getAndBuildChanPolicies(
×
1104
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1105
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1106
        )
×
1107
        if err != nil {
×
1108
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1109
                        "channel policies: %w", err)
×
1110
        }
×
1111

1112
        return ChannelEdge{
×
1113
                Info:    channel,
×
1114
                Policy1: p1,
×
1115
                Policy2: p2,
×
1116
                Node1:   node1,
×
1117
                Node2:   node2,
×
1118
        }, nil
×
1119
}
1120

1121
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1122
// This method acquires the cache lock only once for the entire batch.
1123
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1124
        if len(edgesToCache) == 0 {
×
1125
                return
×
1126
        }
×
1127

1128
        s.cacheMu.Lock()
×
1129
        defer s.cacheMu.Unlock()
×
1130

×
1131
        for chanID, edge := range edgesToCache {
×
1132
                s.chanCache.insert(chanID, edge)
×
1133
        }
×
1134
}
1135

1136
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1137
// one edge that has an update timestamp within the specified horizon.
1138
//
1139
// Iterator Lifecycle:
1140
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1141
// 2. Query batch of channels with policies in time range
1142
// 3. For each channel: check if seen, check cache, or build from DB
1143
// 4. Yield channels to caller
1144
// 5. Update cache after successful batch
1145
// 6. Repeat with updated pagination cursor until no more results
1146
//
1147
// NOTE: This is part of the Store interface.
1148
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1149
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1150

×
1151
        // Apply options.
×
1152
        cfg := defaultIteratorConfig()
×
1153
        for _, opt := range opts {
×
1154
                opt(cfg)
×
1155
        }
×
1156

1157
        return func(yield func(ChannelEdge, error) bool) {
×
1158
                var (
×
1159
                        ctx            = context.TODO()
×
1160
                        edgesSeen      = make(map[uint64]struct{})
×
1161
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1162
                        hits           int
×
1163
                        total          int
×
1164
                        lastUpdateTime sql.NullInt64
×
1165
                        lastID         sql.NullInt64
×
1166
                        hasMore        = true
×
1167
                )
×
1168

×
1169
                // Each iteration, we'll read a batch amount of channel updates
×
1170
                // (consulting the cache along the way), yield them, then loop
×
1171
                // back to decide if we have any more updates to read out.
×
1172
                for hasMore {
×
1173
                        var batch []ChannelEdge
×
1174

×
1175
                        // Acquire read lock before starting transaction to
×
1176
                        // ensure consistent lock ordering (cacheMu -> DB) and
×
1177
                        // prevent deadlock with write operations.
×
1178
                        s.cacheMu.RLock()
×
1179

×
1180
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1181
                                func(db SQLQueries) error {
×
1182
                                        //nolint:ll
×
1183
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1184
                                                Version: int16(lnwire.GossipVersion1),
×
1185
                                                StartTime: sqldb.SQLInt64(
×
1186
                                                        startTime.Unix(),
×
1187
                                                ),
×
1188
                                                EndTime: sqldb.SQLInt64(
×
1189
                                                        endTime.Unix(),
×
1190
                                                ),
×
1191
                                                LastUpdateTime: lastUpdateTime,
×
1192
                                                LastID:         lastID,
×
1193
                                                MaxResults: sql.NullInt32{
×
1194
                                                        Int32: int32(
×
1195
                                                                cfg.chanUpdateIterBatchSize,
×
1196
                                                        ),
×
1197
                                                        Valid: true,
×
1198
                                                },
×
1199
                                        }
×
1200
                                        //nolint:ll
×
1201
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1202
                                                ctx, params,
×
1203
                                        )
×
1204
                                        if err != nil {
×
1205
                                                return err
×
1206
                                        }
×
1207

1208
                                        //nolint:ll
1209
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1210

×
1211
                                        //nolint:ll
×
1212
                                        for _, row := range rows {
×
1213
                                                lastUpdateTime = sql.NullInt64{
×
1214
                                                        Int64: extractMaxUpdateTime(row),
×
1215
                                                        Valid: true,
×
1216
                                                }
×
1217
                                                lastID = sql.NullInt64{
×
1218
                                                        Int64: row.GraphChannel.ID,
×
1219
                                                        Valid: true,
×
1220
                                                }
×
1221

×
1222
                                                // Skip if we've already
×
1223
                                                // processed this channel.
×
1224
                                                chanIDInt := byteOrder.Uint64(
×
1225
                                                        row.GraphChannel.Scid,
×
1226
                                                )
×
1227
                                                _, ok := edgesSeen[chanIDInt]
×
1228
                                                if ok {
×
1229
                                                        continue
×
1230
                                                }
1231

1232
                                                // Check cache (we already hold
1233
                                                // shared read lock).
1234
                                                channel, ok := s.chanCache.get(
×
1235
                                                        chanIDInt,
×
1236
                                                )
×
1237
                                                if ok {
×
1238
                                                        hits++
×
1239
                                                        total++
×
1240
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1241
                                                        batch = append(batch, channel)
×
1242

×
1243
                                                        continue
×
1244
                                                }
1245

1246
                                                chanEdge, err := s.buildChannelFromRow(
×
1247
                                                        ctx, db, row,
×
1248
                                                )
×
1249
                                                if err != nil {
×
1250
                                                        return err
×
1251
                                                }
×
1252

1253
                                                edgesSeen[chanIDInt] = struct{}{}
×
1254
                                                edgesToCache[chanIDInt] = chanEdge
×
1255

×
1256
                                                batch = append(batch, chanEdge)
×
1257

×
1258
                                                total++
×
1259
                                        }
1260

1261
                                        return nil
×
1262
                                }, func() {
×
1263
                                        batch = nil
×
1264
                                        edgesSeen = make(map[uint64]struct{})
×
1265
                                        edgesToCache = make(
×
1266
                                                map[uint64]ChannelEdge,
×
1267
                                        )
×
1268
                                })
×
1269

1270
                        // Release read lock after transaction completes.
1271
                        s.cacheMu.RUnlock()
×
1272

×
1273
                        if err != nil {
×
1274
                                log.Errorf("ChanUpdatesInHorizon "+
×
1275
                                        "batch error: %v", err)
×
1276

×
1277
                                yield(ChannelEdge{}, err)
×
1278

×
1279
                                return
×
1280
                        }
×
1281

1282
                        for _, edge := range batch {
×
1283
                                if !yield(edge, nil) {
×
1284
                                        return
×
1285
                                }
×
1286
                        }
1287

1288
                        // Update cache after successful batch yield, setting
1289
                        // the cache lock only once for the entire batch.
1290
                        s.updateChanCacheBatch(edgesToCache)
×
1291
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1292

×
1293
                        // If the batch didn't yield anything, then we're done.
×
1294
                        if len(batch) == 0 {
×
1295
                                break
×
1296
                        }
1297
                }
1298

1299
                if total > 0 {
×
1300
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1301
                                "%.2f (%d/%d)",
×
1302
                                float64(hits)*100/float64(total), hits, total)
×
1303
                } else {
×
1304
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1305
                                "in horizon (%s, %s)", startTime, endTime)
×
1306
                }
×
1307
        }
1308
}
1309

1310
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1311
// data to the call-back. If withAddrs is true, then the call-back will also be
1312
// provided with the addresses associated with the node. The address retrieval
1313
// result in an additional round-trip to the database, so it should only be used
1314
// if the addresses are actually needed.
1315
//
1316
// NOTE: part of the Store interface.
1317
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1318
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1319
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1320

×
1321
        type nodeCachedBatchData struct {
×
1322
                features      map[int64][]int
×
1323
                addrs         map[int64][]nodeAddress
×
1324
                chanBatchData *batchChannelData
×
1325
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1326
        }
×
1327

×
1328
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1329
                // pageQueryFunc is used to query the next page of nodes.
×
1330
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1331
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1332

×
1333
                        return db.ListNodeIDsAndPubKeys(
×
1334
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1335
                                        Version: int16(lnwire.GossipVersion1),
×
1336
                                        ID:      lastID,
×
1337
                                        Limit:   limit,
×
1338
                                },
×
1339
                        )
×
1340
                }
×
1341

1342
                // batchDataFunc is then used to batch load the data required
1343
                // for each page of nodes.
1344
                batchDataFunc := func(ctx context.Context,
×
1345
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1346

×
1347
                        // Batch load node features.
×
1348
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1349
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1350
                        )
×
1351
                        if err != nil {
×
1352
                                return nil, fmt.Errorf("unable to batch load "+
×
1353
                                        "node features: %w", err)
×
1354
                        }
×
1355

1356
                        // Maybe fetch the node's addresses if requested.
1357
                        var nodeAddrs map[int64][]nodeAddress
×
1358
                        if withAddrs {
×
1359
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1360
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1361
                                )
×
1362
                                if err != nil {
×
1363
                                        return nil, fmt.Errorf("unable to "+
×
1364
                                                "batch load node "+
×
1365
                                                "addresses: %w", err)
×
1366
                                }
×
1367
                        }
1368

1369
                        // Batch load ALL unique channels for ALL nodes in this
1370
                        // page.
1371
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1372
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1373
                                        Version:  int16(lnwire.GossipVersion1),
×
1374
                                        Node1Ids: nodeIDs,
×
1375
                                        Node2Ids: nodeIDs,
×
1376
                                },
×
1377
                        )
×
1378
                        if err != nil {
×
1379
                                return nil, fmt.Errorf("unable to batch "+
×
1380
                                        "fetch channels for nodes: %w", err)
×
1381
                        }
×
1382

1383
                        // Deduplicate channels and collect IDs.
1384
                        var (
×
1385
                                allChannelIDs []int64
×
1386
                                allPolicyIDs  []int64
×
1387
                        )
×
1388
                        uniqueChannels := make(
×
1389
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1390
                        )
×
1391

×
1392
                        for _, channel := range allChannels {
×
1393
                                channelID := channel.GraphChannel.ID
×
1394

×
1395
                                // Only process each unique channel once.
×
1396
                                _, exists := uniqueChannels[channelID]
×
1397
                                if exists {
×
1398
                                        continue
×
1399
                                }
1400

1401
                                uniqueChannels[channelID] = channel
×
1402
                                allChannelIDs = append(allChannelIDs, channelID)
×
1403

×
1404
                                if channel.Policy1ID.Valid {
×
1405
                                        allPolicyIDs = append(
×
1406
                                                allPolicyIDs,
×
1407
                                                channel.Policy1ID.Int64,
×
1408
                                        )
×
1409
                                }
×
1410
                                if channel.Policy2ID.Valid {
×
1411
                                        allPolicyIDs = append(
×
1412
                                                allPolicyIDs,
×
1413
                                                channel.Policy2ID.Int64,
×
1414
                                        )
×
1415
                                }
×
1416
                        }
1417

1418
                        // Batch load channel data for all unique channels.
1419
                        channelBatchData, err := batchLoadChannelData(
×
1420
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1421
                                allPolicyIDs,
×
1422
                        )
×
1423
                        if err != nil {
×
1424
                                return nil, fmt.Errorf("unable to batch "+
×
1425
                                        "load channel data: %w", err)
×
1426
                        }
×
1427

1428
                        // Create map of node ID to channels that involve this
1429
                        // node.
1430
                        nodeIDSet := make(map[int64]bool)
×
1431
                        for _, nodeID := range nodeIDs {
×
1432
                                nodeIDSet[nodeID] = true
×
1433
                        }
×
1434

1435
                        nodeChannelMap := make(
×
1436
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1437
                        )
×
1438
                        for _, channel := range uniqueChannels {
×
1439
                                // Add channel to both nodes if they're in our
×
1440
                                // current page.
×
1441
                                node1 := channel.GraphChannel.NodeID1
×
1442
                                if nodeIDSet[node1] {
×
1443
                                        nodeChannelMap[node1] = append(
×
1444
                                                nodeChannelMap[node1], channel,
×
1445
                                        )
×
1446
                                }
×
1447
                                node2 := channel.GraphChannel.NodeID2
×
1448
                                if nodeIDSet[node2] {
×
1449
                                        nodeChannelMap[node2] = append(
×
1450
                                                nodeChannelMap[node2], channel,
×
1451
                                        )
×
1452
                                }
×
1453
                        }
1454

1455
                        return &nodeCachedBatchData{
×
1456
                                features:      nodeFeatures,
×
1457
                                addrs:         nodeAddrs,
×
1458
                                chanBatchData: channelBatchData,
×
1459
                                chanMap:       nodeChannelMap,
×
1460
                        }, nil
×
1461
                }
1462

1463
                // processItem is used to process each node in the current page.
1464
                processItem := func(ctx context.Context,
×
1465
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1466
                        batchData *nodeCachedBatchData) error {
×
1467

×
1468
                        // Build feature vector for this node.
×
1469
                        fv := lnwire.EmptyFeatureVector()
×
1470
                        features, exists := batchData.features[nodeData.ID]
×
1471
                        if exists {
×
1472
                                for _, bit := range features {
×
1473
                                        fv.Set(lnwire.FeatureBit(bit))
×
1474
                                }
×
1475
                        }
1476

1477
                        var nodePub route.Vertex
×
1478
                        copy(nodePub[:], nodeData.PubKey)
×
1479

×
1480
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1481

×
1482
                        toNodeCallback := func() route.Vertex {
×
1483
                                return nodePub
×
1484
                        }
×
1485

1486
                        // Build cached channels map for this node.
1487
                        channels := make(map[uint64]*DirectedChannel)
×
1488
                        for _, channelRow := range nodeChannels {
×
1489
                                directedChan, err := buildDirectedChannel(
×
1490
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1491
                                        channelRow, batchData.chanBatchData, fv,
×
1492
                                        toNodeCallback,
×
1493
                                )
×
1494
                                if err != nil {
×
1495
                                        return err
×
1496
                                }
×
1497

1498
                                channels[directedChan.ChannelID] = directedChan
×
1499
                        }
1500

1501
                        addrs, err := buildNodeAddresses(
×
1502
                                batchData.addrs[nodeData.ID],
×
1503
                        )
×
1504
                        if err != nil {
×
1505
                                return fmt.Errorf("unable to build node "+
×
1506
                                        "addresses: %w", err)
×
1507
                        }
×
1508

1509
                        return cb(ctx, nodePub, addrs, channels)
×
1510
                }
1511

1512
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1513
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1514
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1515
                                return node.ID
×
1516
                        },
×
1517
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1518
                                error) {
×
1519

×
1520
                                return node.ID, nil
×
1521
                        },
×
1522
                        batchDataFunc, processItem,
1523
                )
1524
        }, reset)
1525
}
1526

1527
// ForEachChannelCacheable iterates through all the channel edges stored
1528
// within the graph and invokes the passed callback for each edge. The
1529
// callback takes two edges as since this is a directed graph, both the
1530
// in/out edges are visited. If the callback returns an error, then the
1531
// transaction is aborted and the iteration stops early.
1532
//
1533
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1534
// pointer for that particular channel edge routing policy will be
1535
// passed into the callback.
1536
//
1537
// NOTE: this method is like ForEachChannel but fetches only the data
1538
// required for the graph cache.
1539
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1540
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1541
        reset func()) error {
×
1542

×
1543
        ctx := context.TODO()
×
1544

×
1545
        handleChannel := func(_ context.Context,
×
1546
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1547

×
1548
                node1, node2, err := buildNodeVertices(
×
1549
                        row.Node1Pubkey, row.Node2Pubkey,
×
1550
                )
×
1551
                if err != nil {
×
1552
                        return err
×
1553
                }
×
1554

1555
                edge := buildCacheableChannelInfo(
×
1556
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1557
                )
×
1558

×
1559
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1560
                if err != nil {
×
1561
                        return err
×
1562
                }
×
1563

1564
                pol1, pol2, err := buildCachedChanPolicies(
×
1565
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1566
                )
×
1567
                if err != nil {
×
1568
                        return err
×
1569
                }
×
1570

1571
                return cb(edge, pol1, pol2)
×
1572
        }
1573

1574
        extractCursor := func(
×
1575
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1576

×
1577
                return row.ID
×
1578
        }
×
1579

1580
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1581
                //nolint:ll
×
1582
                queryFunc := func(ctx context.Context, lastID int64,
×
1583
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1584
                        error) {
×
1585

×
1586
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1587
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1588
                                        Version: int16(lnwire.GossipVersion1),
×
1589
                                        ID:      lastID,
×
1590
                                        Limit:   limit,
×
1591
                                },
×
1592
                        )
×
1593
                }
×
1594

1595
                return sqldb.ExecutePaginatedQuery(
×
1596
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1597
                        extractCursor, handleChannel,
×
1598
                )
×
1599
        }, reset)
1600
}
1601

1602
// ForEachChannel iterates through all the channel edges stored within the
1603
// graph and invokes the passed callback for each edge. The callback takes two
1604
// edges as since this is a directed graph, both the in/out edges are visited.
1605
// If the callback returns an error, then the transaction is aborted and the
1606
// iteration stops early.
1607
//
1608
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1609
// for that particular channel edge routing policy will be passed into the
1610
// callback.
1611
//
1612
// NOTE: part of the Store interface.
1613
func (s *SQLStore) ForEachChannel(ctx context.Context,
1614
        v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo,
1615
                *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error,
NEW
1616
        reset func()) error {
×
NEW
1617

×
NEW
1618
        if !isKnownGossipVersion(v) {
×
NEW
1619
                return fmt.Errorf("unsupported gossip version: %d", v)
×
NEW
1620
        }
×
1621

1622
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1623
                return forEachChannelWithPolicies(ctx, db, s.cfg, v, cb)
×
1624
        }, reset)
×
1625
}
1626

1627
// FilterChannelRange returns the channel ID's of all known channels which were
1628
// mined in a block height within the passed range. The channel IDs are grouped
1629
// by their common block height. This method can be used to quickly share with a
1630
// peer the set of channels we know of within a particular range to catch them
1631
// up after a period of time offline. If withTimestamps is true then the
1632
// timestamp info of the latest received channel update messages of the channel
1633
// will be included in the response.
1634
//
1635
// NOTE: This is part of the Store interface.
1636
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1637
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1638

×
1639
        var (
×
1640
                ctx       = context.TODO()
×
1641
                startSCID = &lnwire.ShortChannelID{
×
1642
                        BlockHeight: startHeight,
×
1643
                }
×
1644
                endSCID = lnwire.ShortChannelID{
×
1645
                        BlockHeight: endHeight,
×
1646
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1647
                        TxPosition:  math.MaxUint16,
×
1648
                }
×
1649
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1650
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1651
        )
×
1652

×
1653
        // 1) get all channels where channelID is between start and end chan ID.
×
1654
        // 2) skip if not public (ie, no channel_proof)
×
1655
        // 3) collect that channel.
×
1656
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1657
        //    and add those timestamps to the collected channel.
×
1658
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1659
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1660
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1661
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1662
                                StartScid: chanIDStart,
×
1663
                                EndScid:   chanIDEnd,
×
1664
                        },
×
1665
                )
×
1666
                if err != nil {
×
1667
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1668
                                err)
×
1669
                }
×
1670

1671
                for _, dbChan := range dbChans {
×
1672
                        cid := lnwire.NewShortChanIDFromInt(
×
1673
                                byteOrder.Uint64(dbChan.Scid),
×
1674
                        )
×
1675
                        chanInfo := NewChannelUpdateInfo(
×
1676
                                cid, time.Time{}, time.Time{},
×
1677
                        )
×
1678

×
1679
                        if !withTimestamps {
×
1680
                                channelsPerBlock[cid.BlockHeight] = append(
×
1681
                                        channelsPerBlock[cid.BlockHeight],
×
1682
                                        chanInfo,
×
1683
                                )
×
1684

×
1685
                                continue
×
1686
                        }
1687

1688
                        //nolint:ll
1689
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1690
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1691
                                        Version:   int16(lnwire.GossipVersion1),
×
1692
                                        ChannelID: dbChan.ID,
×
1693
                                        NodeID:    dbChan.NodeID1,
×
1694
                                },
×
1695
                        )
×
1696
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1697
                                return fmt.Errorf("unable to fetch node1 "+
×
1698
                                        "policy: %w", err)
×
1699
                        } else if err == nil {
×
1700
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1701
                                        node1Policy.LastUpdate.Int64, 0,
×
1702
                                )
×
1703
                        }
×
1704

1705
                        //nolint:ll
1706
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1707
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1708
                                        Version:   int16(lnwire.GossipVersion1),
×
1709
                                        ChannelID: dbChan.ID,
×
1710
                                        NodeID:    dbChan.NodeID2,
×
1711
                                },
×
1712
                        )
×
1713
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1714
                                return fmt.Errorf("unable to fetch node2 "+
×
1715
                                        "policy: %w", err)
×
1716
                        } else if err == nil {
×
1717
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1718
                                        node2Policy.LastUpdate.Int64, 0,
×
1719
                                )
×
1720
                        }
×
1721

1722
                        channelsPerBlock[cid.BlockHeight] = append(
×
1723
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1724
                        )
×
1725
                }
1726

1727
                return nil
×
1728
        }, func() {
×
1729
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1730
        })
×
1731
        if err != nil {
×
1732
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1733
        }
×
1734

1735
        if len(channelsPerBlock) == 0 {
×
1736
                return nil, nil
×
1737
        }
×
1738

1739
        // Return the channel ranges in ascending block height order.
1740
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1741
        slices.Sort(blocks)
×
1742

×
1743
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1744
                return BlockChannelRange{
×
1745
                        Height:   block,
×
1746
                        Channels: channelsPerBlock[block],
×
1747
                }
×
1748
        }), nil
×
1749
}
1750

1751
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1752
// zombie. This method is used on an ad-hoc basis, when channels need to be
1753
// marked as zombies outside the normal pruning cycle.
1754
//
1755
// NOTE: part of the Store interface.
1756
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1757
        pubKey1, pubKey2 [33]byte) error {
×
1758

×
1759
        ctx := context.TODO()
×
1760

×
1761
        s.cacheMu.Lock()
×
1762
        defer s.cacheMu.Unlock()
×
1763

×
1764
        chanIDB := channelIDToBytes(chanID)
×
1765

×
1766
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1767
                return db.UpsertZombieChannel(
×
1768
                        ctx, sqlc.UpsertZombieChannelParams{
×
1769
                                Version:  int16(lnwire.GossipVersion1),
×
1770
                                Scid:     chanIDB,
×
1771
                                NodeKey1: pubKey1[:],
×
1772
                                NodeKey2: pubKey2[:],
×
1773
                        },
×
1774
                )
×
1775
        }, sqldb.NoOpReset)
×
1776
        if err != nil {
×
1777
                return fmt.Errorf("unable to upsert zombie channel "+
×
1778
                        "(channel_id=%d): %w", chanID, err)
×
1779
        }
×
1780

NEW
1781
        s.rejectCache.remove(lnwire.GossipVersion1, chanID)
×
1782
        s.chanCache.remove(chanID)
×
1783

×
1784
        return nil
×
1785
}
1786

1787
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1788
//
1789
// NOTE: part of the Store interface.
1790
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1791
        s.cacheMu.Lock()
×
1792
        defer s.cacheMu.Unlock()
×
1793

×
1794
        var (
×
1795
                ctx     = context.TODO()
×
1796
                chanIDB = channelIDToBytes(chanID)
×
1797
        )
×
1798

×
1799
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1800
                res, err := db.DeleteZombieChannel(
×
1801
                        ctx, sqlc.DeleteZombieChannelParams{
×
1802
                                Scid:    chanIDB,
×
1803
                                Version: int16(lnwire.GossipVersion1),
×
1804
                        },
×
1805
                )
×
1806
                if err != nil {
×
1807
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1808
                                err)
×
1809
                }
×
1810

1811
                rows, err := res.RowsAffected()
×
1812
                if err != nil {
×
1813
                        return err
×
1814
                }
×
1815

1816
                if rows == 0 {
×
1817
                        return ErrZombieEdgeNotFound
×
1818
                } else if rows > 1 {
×
1819
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1820
                                "expected 1", rows)
×
1821
                }
×
1822

1823
                return nil
×
1824
        }, sqldb.NoOpReset)
1825
        if err != nil {
×
1826
                return fmt.Errorf("unable to mark edge live "+
×
1827
                        "(channel_id=%d): %w", chanID, err)
×
1828
        }
×
1829

NEW
1830
        s.rejectCache.remove(lnwire.GossipVersion1, chanID)
×
1831
        s.chanCache.remove(chanID)
×
1832

×
1833
        return err
×
1834
}
1835

1836
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1837
// zombie, then the two node public keys corresponding to this edge are also
1838
// returned.
1839
//
1840
// NOTE: part of the Store interface.
1841
func (s *SQLStore) IsZombieEdge(v lnwire.GossipVersion,
NEW
1842
        chanID uint64) (bool, [33]byte, [33]byte, error) {
×
1843

×
1844
        var (
×
1845
                ctx              = context.TODO()
×
1846
                isZombie         bool
×
1847
                pubKey1, pubKey2 route.Vertex
×
1848
                chanIDB          = channelIDToBytes(chanID)
×
1849
        )
×
1850

×
NEW
1851
        if !isKnownGossipVersion(v) {
×
NEW
1852
                return false, [33]byte{}, [33]byte{},
×
NEW
1853
                        fmt.Errorf("unsupported gossip version: %d", v)
×
NEW
1854
        }
×
1855

1856
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1857
                zombie, err := db.GetZombieChannel(
×
1858
                        ctx, sqlc.GetZombieChannelParams{
×
1859
                                Scid:    chanIDB,
×
NEW
1860
                                Version: int16(v),
×
1861
                        },
×
1862
                )
×
1863
                if errors.Is(err, sql.ErrNoRows) {
×
1864
                        return nil
×
1865
                }
×
1866
                if err != nil {
×
1867
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1868
                                err)
×
1869
                }
×
1870

1871
                copy(pubKey1[:], zombie.NodeKey1)
×
1872
                copy(pubKey2[:], zombie.NodeKey2)
×
1873
                isZombie = true
×
1874

×
1875
                return nil
×
1876
        }, sqldb.NoOpReset)
1877
        if err != nil {
×
1878
                return false, route.Vertex{}, route.Vertex{},
×
1879
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1880
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1881
        }
×
1882

1883
        return isZombie, pubKey1, pubKey2, nil
×
1884
}
1885

1886
// NumZombies returns the current number of zombie channels in the graph.
1887
//
1888
// NOTE: part of the Store interface.
1889
func (s *SQLStore) NumZombies() (uint64, error) {
×
1890
        var (
×
1891
                ctx        = context.TODO()
×
1892
                numZombies uint64
×
1893
        )
×
1894
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1895
                count, err := db.CountZombieChannels(
×
1896
                        ctx, int16(lnwire.GossipVersion1),
×
1897
                )
×
1898
                if err != nil {
×
1899
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1900
                                err)
×
1901
                }
×
1902

1903
                numZombies = uint64(count)
×
1904

×
1905
                return nil
×
1906
        }, sqldb.NoOpReset)
1907
        if err != nil {
×
1908
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1909
        }
×
1910

1911
        return numZombies, nil
×
1912
}
1913

1914
// DeleteChannelEdges removes edges with the given channel IDs from the
1915
// database and marks them as zombies. This ensures that we're unable to re-add
1916
// it to our database once again. If an edge does not exist within the
1917
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1918
// true, then when we mark these edges as zombies, we'll set up the keys such
1919
// that we require the node that failed to send the fresh update to be the one
1920
// that resurrects the channel from its zombie state. The markZombie bool
1921
// denotes whether to mark the channel as a zombie.
1922
//
1923
// NOTE: part of the Store interface.
1924
func (s *SQLStore) DeleteChannelEdges(v lnwire.GossipVersion,
1925
        strictZombiePruning, markZombie bool, chanIDs ...uint64) (
NEW
1926
        []*models.ChannelEdgeInfo, error) {
×
1927

×
1928
        s.cacheMu.Lock()
×
1929
        defer s.cacheMu.Unlock()
×
1930

×
1931
        // Keep track of which channels we end up finding so that we can
×
1932
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1933
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1934
        for _, chanID := range chanIDs {
×
1935
                chanLookup[chanID] = struct{}{}
×
1936
        }
×
1937

1938
        var (
×
1939
                ctx   = context.TODO()
×
1940
                edges []*models.ChannelEdgeInfo
×
1941
        )
×
1942
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1943
                // First, collect all channel rows.
×
1944
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1945
                chanCallBack := func(ctx context.Context,
×
1946
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1947

×
1948
                        // Deleting the entry from the map indicates that we
×
1949
                        // have found the channel.
×
1950
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1951
                        delete(chanLookup, scid)
×
1952

×
1953
                        channelRows = append(channelRows, row)
×
1954

×
1955
                        return nil
×
1956
                }
×
1957

1958
                err := s.forEachChanWithPoliciesInSCIDList(
×
NEW
1959
                        ctx, db, v, chanCallBack, chanIDs,
×
1960
                )
×
1961
                if err != nil {
×
1962
                        return err
×
1963
                }
×
1964

1965
                if len(chanLookup) > 0 {
×
1966
                        return ErrEdgeNotFound
×
1967
                }
×
1968

1969
                if len(channelRows) == 0 {
×
1970
                        return nil
×
1971
                }
×
1972

1973
                // Batch build all channel edges.
1974
                var chanIDsToDelete []int64
×
1975
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1976
                        ctx, s.cfg, db, channelRows,
×
1977
                )
×
1978
                if err != nil {
×
1979
                        return err
×
1980
                }
×
1981

1982
                if markZombie {
×
1983
                        for i, row := range channelRows {
×
1984
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1985

×
1986
                                err := handleZombieMarking(
×
NEW
1987
                                        ctx, db, v, row, edges[i],
×
1988
                                        strictZombiePruning, scid,
×
1989
                                )
×
1990
                                if err != nil {
×
1991
                                        return fmt.Errorf("unable to mark "+
×
1992
                                                "channel as zombie: %w", err)
×
1993
                                }
×
1994
                        }
1995
                }
1996

1997
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1998
        }, func() {
×
1999
                edges = nil
×
2000

×
2001
                // Re-fill the lookup map.
×
2002
                for _, chanID := range chanIDs {
×
2003
                        chanLookup[chanID] = struct{}{}
×
2004
                }
×
2005
        })
2006
        if err != nil {
×
2007
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
2008
                        err)
×
2009
        }
×
2010

2011
        for _, chanID := range chanIDs {
×
NEW
2012
                s.rejectCache.remove(v, chanID)
×
2013
                s.chanCache.remove(chanID)
×
2014
        }
×
2015

2016
        return edges, nil
×
2017
}
2018

2019
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
2020
// channel identified by the channel ID. If the channel can't be found, then
2021
// ErrEdgeNotFound is returned. A struct which houses the general information
2022
// for the channel itself is returned as well as two structs that contain the
2023
// routing policies for the channel in either direction.
2024
//
2025
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
2026
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
2027
// the ChannelEdgeInfo will only include the public keys of each node.
2028
//
2029
// NOTE: part of the Store interface.
2030
func (s *SQLStore) FetchChannelEdgesByID(v lnwire.GossipVersion,
2031
        chanID uint64) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2032
        *models.ChannelEdgePolicy, error) {
×
2033

×
2034
        var (
×
2035
                ctx              = context.TODO()
×
2036
                edge             *models.ChannelEdgeInfo
×
2037
                policy1, policy2 *models.ChannelEdgePolicy
×
2038
                chanIDB          = channelIDToBytes(chanID)
×
2039
        )
×
NEW
2040

×
NEW
2041
        if !isKnownGossipVersion(v) {
×
NEW
2042
                return nil, nil, nil, fmt.Errorf(
×
NEW
2043
                        "unsupported gossip version: %d", v,
×
NEW
2044
                )
×
NEW
2045
        }
×
2046

2047
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2048
                row, err := db.GetChannelBySCIDWithPolicies(
×
2049
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2050
                                Scid:    chanIDB,
×
NEW
2051
                                Version: int16(v),
×
2052
                        },
×
2053
                )
×
2054
                if errors.Is(err, sql.ErrNoRows) {
×
2055
                        // First check if this edge is perhaps in the zombie
×
2056
                        // index.
×
2057
                        zombie, err := db.GetZombieChannel(
×
2058
                                ctx, sqlc.GetZombieChannelParams{
×
2059
                                        Scid:    chanIDB,
×
NEW
2060
                                        Version: int16(v),
×
2061
                                },
×
2062
                        )
×
2063
                        if errors.Is(err, sql.ErrNoRows) {
×
2064
                                return ErrEdgeNotFound
×
2065
                        } else if err != nil {
×
2066
                                return fmt.Errorf("unable to check if "+
×
2067
                                        "channel is zombie: %w", err)
×
2068
                        }
×
2069

2070
                        // At this point, we know the channel is a zombie, so
2071
                        // we'll return an error indicating this, and we will
2072
                        // populate the edge info with the public keys of each
2073
                        // party as this is the only information we have about
2074
                        // it.
NEW
2075
                        node1, err := route.NewVertexFromBytes(zombie.NodeKey1)
×
NEW
2076
                        if err != nil {
×
NEW
2077
                                return err
×
NEW
2078
                        }
×
NEW
2079
                        node2, err := route.NewVertexFromBytes(zombie.NodeKey2)
×
NEW
2080
                        if err != nil {
×
NEW
2081
                                return err
×
NEW
2082
                        }
×
NEW
2083
                        zombieEdge, err := models.NewV1Channel(
×
NEW
2084
                                0, chainhash.Hash{}, node1, node2,
×
NEW
2085
                                &models.ChannelV1Fields{},
×
NEW
2086
                        )
×
NEW
2087
                        if err != nil {
×
NEW
2088
                                return err
×
NEW
2089
                        }
×
NEW
2090
                        edge = zombieEdge
×
2091

×
2092
                        return ErrZombieEdge
×
2093
                } else if err != nil {
×
2094
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2095
                }
×
2096

2097
                node1, node2, err := buildNodeVertices(
×
2098
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2099
                )
×
2100
                if err != nil {
×
2101
                        return err
×
2102
                }
×
2103

2104
                edge, err = getAndBuildEdgeInfo(
×
2105
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2106
                )
×
2107
                if err != nil {
×
2108
                        return fmt.Errorf("unable to build channel info: %w",
×
2109
                                err)
×
2110
                }
×
2111

2112
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2113
                if err != nil {
×
2114
                        return fmt.Errorf("unable to extract channel "+
×
2115
                                "policies: %w", err)
×
2116
                }
×
2117

2118
                policy1, policy2, err = getAndBuildChanPolicies(
×
2119
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2120
                        node1, node2,
×
2121
                )
×
2122
                if err != nil {
×
2123
                        return fmt.Errorf("unable to build channel "+
×
2124
                                "policies: %w", err)
×
2125
                }
×
2126

2127
                return nil
×
2128
        }, sqldb.NoOpReset)
2129
        if err != nil {
×
2130
                // If we are returning the ErrZombieEdge, then we also need to
×
2131
                // return the edge info as the method comment indicates that
×
2132
                // this will be populated when the edge is a zombie.
×
2133
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2134
                        err)
×
2135
        }
×
2136

2137
        return edge, policy1, policy2, nil
×
2138
}
2139

2140
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2141
// the channel identified by the funding outpoint. If the channel can't be
2142
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2143
// information for the channel itself is returned as well as two structs that
2144
// contain the routing policies for the channel in either direction.
2145
//
2146
// NOTE: part of the Store interface.
2147
func (s *SQLStore) FetchChannelEdgesByOutpoint(v lnwire.GossipVersion,
2148
        op *wire.OutPoint) (*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2149
        *models.ChannelEdgePolicy, error) {
×
2150

×
2151
        var (
×
2152
                ctx              = context.TODO()
×
2153
                edge             *models.ChannelEdgeInfo
×
2154
                policy1, policy2 *models.ChannelEdgePolicy
×
2155
        )
×
NEW
2156

×
NEW
2157
        if !isKnownGossipVersion(v) {
×
NEW
2158
                return nil, nil, nil, fmt.Errorf(
×
NEW
2159
                        "unsupported gossip version: %d", v,
×
NEW
2160
                )
×
NEW
2161
        }
×
2162

2163
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2164
                row, err := db.GetChannelByOutpointWithPolicies(
×
2165
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2166
                                Outpoint: op.String(),
×
NEW
2167
                                Version:  int16(v),
×
2168
                        },
×
2169
                )
×
2170
                if errors.Is(err, sql.ErrNoRows) {
×
2171
                        return ErrEdgeNotFound
×
2172
                } else if err != nil {
×
2173
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2174
                }
×
2175

2176
                node1, node2, err := buildNodeVertices(
×
2177
                        row.Node1Pubkey, row.Node2Pubkey,
×
2178
                )
×
2179
                if err != nil {
×
2180
                        return err
×
2181
                }
×
2182

2183
                edge, err = getAndBuildEdgeInfo(
×
2184
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2185
                )
×
2186
                if err != nil {
×
2187
                        return fmt.Errorf("unable to build channel info: %w",
×
2188
                                err)
×
2189
                }
×
2190

2191
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2192
                if err != nil {
×
2193
                        return fmt.Errorf("unable to extract channel "+
×
2194
                                "policies: %w", err)
×
2195
                }
×
2196

2197
                policy1, policy2, err = getAndBuildChanPolicies(
×
2198
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2199
                        node1, node2,
×
2200
                )
×
2201
                if err != nil {
×
2202
                        return fmt.Errorf("unable to build channel "+
×
2203
                                "policies: %w", err)
×
2204
                }
×
2205

2206
                return nil
×
2207
        }, sqldb.NoOpReset)
2208
        if err != nil {
×
2209
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2210
                        err)
×
2211
        }
×
2212

2213
        return edge, policy1, policy2, nil
×
2214
}
2215

2216
// HasV1ChannelEdge returns true if the database knows of a channel edge
2217
// with the passed channel ID, and false otherwise. If an edge with that ID
2218
// is found within the graph, then two time stamps representing the last time
2219
// the edge was updated for both directed edges are returned along with the
2220
// boolean. If it is not found, then the zombie index is checked and its
2221
// result is returned as the second boolean.
2222
//
2223
// NOTE: part of the Store interface.
2224
func (s *SQLStore) HasV1ChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2225
        bool, error) {
×
2226

×
2227
        ctx := context.TODO()
×
2228

×
2229
        var (
×
2230
                exists          bool
×
2231
                isZombie        bool
×
2232
                node1LastUpdate time.Time
×
2233
                node2LastUpdate time.Time
×
2234
        )
×
2235

×
2236
        // We'll query the cache with the shared lock held to allow multiple
×
2237
        // readers to access values in the cache concurrently if they exist.
×
2238
        s.cacheMu.RLock()
×
NEW
2239
        if entry, ok := s.rejectCache.get(lnwire.GossipVersion1, chanID); ok {
×
2240
                s.cacheMu.RUnlock()
×
2241
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2242
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2243
                exists, isZombie = entry.flags.unpack()
×
2244

×
2245
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2246
        }
×
2247
        s.cacheMu.RUnlock()
×
2248

×
2249
        s.cacheMu.Lock()
×
2250
        defer s.cacheMu.Unlock()
×
2251

×
2252
        // The item was not found with the shared lock, so we'll acquire the
×
2253
        // exclusive lock and check the cache again in case another method added
×
2254
        // the entry to the cache while no lock was held.
×
NEW
2255
        if entry, ok := s.rejectCache.get(lnwire.GossipVersion1, chanID); ok {
×
2256
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2257
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2258
                exists, isZombie = entry.flags.unpack()
×
2259

×
2260
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2261
        }
×
2262

2263
        chanIDB := channelIDToBytes(chanID)
×
2264
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2265
                channel, err := db.GetChannelBySCID(
×
2266
                        ctx, sqlc.GetChannelBySCIDParams{
×
2267
                                Scid:    chanIDB,
×
2268
                                Version: int16(lnwire.GossipVersion1),
×
2269
                        },
×
2270
                )
×
2271
                if errors.Is(err, sql.ErrNoRows) {
×
2272
                        // Check if it is a zombie channel.
×
2273
                        isZombie, err = db.IsZombieChannel(
×
2274
                                ctx, sqlc.IsZombieChannelParams{
×
2275
                                        Scid:    chanIDB,
×
2276
                                        Version: int16(lnwire.GossipVersion1),
×
2277
                                },
×
2278
                        )
×
2279
                        if err != nil {
×
2280
                                return fmt.Errorf("could not check if channel "+
×
2281
                                        "is zombie: %w", err)
×
2282
                        }
×
2283

2284
                        return nil
×
2285
                } else if err != nil {
×
2286
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2287
                }
×
2288

2289
                exists = true
×
2290

×
2291
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2292
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2293
                                Version:   int16(lnwire.GossipVersion1),
×
2294
                                ChannelID: channel.ID,
×
2295
                                NodeID:    channel.NodeID1,
×
2296
                        },
×
2297
                )
×
2298
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2299
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2300
                                err)
×
2301
                } else if err == nil {
×
2302
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2303
                }
×
2304

2305
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2306
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2307
                                Version:   int16(lnwire.GossipVersion1),
×
2308
                                ChannelID: channel.ID,
×
2309
                                NodeID:    channel.NodeID2,
×
2310
                        },
×
2311
                )
×
2312
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2313
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2314
                                err)
×
2315
                } else if err == nil {
×
2316
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2317
                }
×
2318

2319
                return nil
×
2320
        }, sqldb.NoOpReset)
2321
        if err != nil {
×
2322
                return time.Time{}, time.Time{}, false, false,
×
2323
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2324
        }
×
2325

NEW
2326
        s.rejectCache.insert(
×
NEW
2327
                lnwire.GossipVersion1, chanID,
×
NEW
2328
                newRejectCacheEntryV1(
×
NEW
2329
                        node1LastUpdate, node2LastUpdate, exists,
×
NEW
2330
                        isZombie,
×
NEW
2331
                ),
×
NEW
2332
        )
×
2333

×
2334
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2335
}
2336

2337
// HasChannelEdge returns true if the database knows of a channel edge with the
2338
// passed channel ID and gossip version, and false otherwise. If an edge with
2339
// that ID is found within the graph, then the zombie index is checked and its
2340
// result is returned as the second boolean.
2341
//
2342
// NOTE: part of the Store interface.
2343
func (s *SQLStore) HasChannelEdge(v lnwire.GossipVersion,
NEW
2344
        chanID uint64) (bool, bool, error) {
×
NEW
2345

×
NEW
2346
        if !isKnownGossipVersion(v) {
×
NEW
2347
                return false, false, fmt.Errorf(
×
NEW
2348
                        "unsupported gossip version: %d", v,
×
NEW
2349
                )
×
NEW
2350
        }
×
2351

NEW
2352
        ctx := context.TODO()
×
NEW
2353

×
NEW
2354
        var (
×
NEW
2355
                exists          bool
×
NEW
2356
                isZombie        bool
×
NEW
2357
                node1LastUpdate time.Time
×
NEW
2358
                node2LastUpdate time.Time
×
NEW
2359
                node1Block      uint32
×
NEW
2360
                node2Block      uint32
×
NEW
2361
        )
×
NEW
2362

×
NEW
2363
        // We'll query the cache with the shared lock held to allow multiple
×
NEW
2364
        // readers to access values in the cache concurrently if they exist.
×
NEW
2365
        s.cacheMu.RLock()
×
NEW
2366
        if entry, ok := s.rejectCache.get(v, chanID); ok {
×
NEW
2367
                s.cacheMu.RUnlock()
×
NEW
2368
                exists, isZombie = entry.flags.unpack()
×
NEW
2369
                return exists, isZombie, nil
×
NEW
2370
        }
×
NEW
2371
        s.cacheMu.RUnlock()
×
NEW
2372

×
NEW
2373
        s.cacheMu.Lock()
×
NEW
2374
        defer s.cacheMu.Unlock()
×
NEW
2375

×
NEW
2376
        // The item was not found with the shared lock, so we'll acquire the
×
NEW
2377
        // exclusive lock and check the cache again in case another method added
×
NEW
2378
        // the entry to the cache while no lock was held.
×
NEW
2379
        if entry, ok := s.rejectCache.get(v, chanID); ok {
×
NEW
2380
                exists, isZombie = entry.flags.unpack()
×
NEW
2381
                return exists, isZombie, nil
×
NEW
2382
        }
×
2383

NEW
2384
        chanIDB := channelIDToBytes(chanID)
×
NEW
2385
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
2386
                channel, err := db.GetChannelBySCID(
×
NEW
2387
                        ctx, sqlc.GetChannelBySCIDParams{
×
NEW
2388
                                Scid:    chanIDB,
×
NEW
2389
                                Version: int16(v),
×
NEW
2390
                        },
×
NEW
2391
                )
×
NEW
2392
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
2393
                        // Check if it is a zombie channel.
×
NEW
2394
                        isZombie, err = db.IsZombieChannel(
×
NEW
2395
                                ctx, sqlc.IsZombieChannelParams{
×
NEW
2396
                                        Scid:    chanIDB,
×
NEW
2397
                                        Version: int16(v),
×
NEW
2398
                                },
×
NEW
2399
                        )
×
NEW
2400
                        if err != nil {
×
NEW
2401
                                return fmt.Errorf("could not check if channel "+
×
NEW
2402
                                        "is zombie: %w", err)
×
NEW
2403
                        }
×
2404

NEW
2405
                        return nil
×
NEW
2406
                } else if err != nil {
×
NEW
2407
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
2408
                }
×
2409

NEW
2410
                exists = true
×
NEW
2411

×
NEW
2412
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
2413
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
2414
                                Version:   int16(v),
×
NEW
2415
                                ChannelID: channel.ID,
×
NEW
2416
                                NodeID:    channel.NodeID1,
×
NEW
2417
                        },
×
NEW
2418
                )
×
NEW
2419
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
2420
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
NEW
2421
                                err)
×
NEW
2422
                } else if err == nil {
×
NEW
2423
                        switch v {
×
NEW
2424
                        case lnwire.GossipVersion1:
×
NEW
2425
                                if policy1.LastUpdate.Valid {
×
NEW
2426
                                        node1LastUpdate = time.Unix(
×
NEW
2427
                                                policy1.LastUpdate.Int64, 0,
×
NEW
2428
                                        )
×
NEW
2429
                                }
×
NEW
2430
                        case lnwire.GossipVersion2:
×
NEW
2431
                                if policy1.BlockHeight.Valid {
×
NEW
2432
                                        node1Block = uint32(
×
NEW
2433
                                                policy1.BlockHeight.Int64,
×
NEW
2434
                                        )
×
NEW
2435
                                }
×
2436
                        }
2437
                }
2438

NEW
2439
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
NEW
2440
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
2441
                                Version:   int16(v),
×
NEW
2442
                                ChannelID: channel.ID,
×
NEW
2443
                                NodeID:    channel.NodeID2,
×
NEW
2444
                        },
×
NEW
2445
                )
×
NEW
2446
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
NEW
2447
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
NEW
2448
                                err)
×
NEW
2449
                } else if err == nil {
×
NEW
2450
                        switch v {
×
NEW
2451
                        case lnwire.GossipVersion1:
×
NEW
2452
                                if policy2.LastUpdate.Valid {
×
NEW
2453
                                        node2LastUpdate = time.Unix(
×
NEW
2454
                                                policy2.LastUpdate.Int64, 0,
×
NEW
2455
                                        )
×
NEW
2456
                                }
×
NEW
2457
                        case lnwire.GossipVersion2:
×
NEW
2458
                                if policy2.BlockHeight.Valid {
×
NEW
2459
                                        node2Block = uint32(
×
NEW
2460
                                                policy2.BlockHeight.Int64,
×
NEW
2461
                                        )
×
NEW
2462
                                }
×
2463
                        }
2464
                }
2465

NEW
2466
                return nil
×
2467
        }, sqldb.NoOpReset)
NEW
2468
        if err != nil {
×
NEW
2469
                return false, false,
×
NEW
2470
                        fmt.Errorf("unable to fetch channel: %w", err)
×
NEW
2471
        }
×
2472

NEW
2473
        var entry rejectCacheEntry
×
NEW
2474
        switch v {
×
NEW
2475
        case lnwire.GossipVersion1:
×
NEW
2476
                entry = newRejectCacheEntryV1(
×
NEW
2477
                        node1LastUpdate, node2LastUpdate, exists, isZombie,
×
NEW
2478
                )
×
NEW
2479
        case lnwire.GossipVersion2:
×
NEW
2480
                entry = newRejectCacheEntryV2(
×
NEW
2481
                        node1Block, node2Block, exists, isZombie,
×
NEW
2482
                )
×
2483
        }
NEW
2484
        s.rejectCache.insert(v, chanID, entry)
×
NEW
2485

×
NEW
2486
        return exists, isZombie, nil
×
2487
}
2488

2489
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2490
// passed channel point (outpoint). If the passed channel doesn't exist within
2491
// the database, then ErrEdgeNotFound is returned.
2492
//
2493
// NOTE: part of the Store interface.
2494
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2495
        var (
×
2496
                ctx       = context.TODO()
×
2497
                channelID uint64
×
2498
        )
×
2499
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2500
                chanID, err := db.GetSCIDByOutpoint(
×
2501
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2502
                                Outpoint: chanPoint.String(),
×
2503
                                Version:  int16(lnwire.GossipVersion1),
×
2504
                        },
×
2505
                )
×
2506
                if errors.Is(err, sql.ErrNoRows) {
×
2507
                        return ErrEdgeNotFound
×
2508
                } else if err != nil {
×
2509
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2510
                                err)
×
2511
                }
×
2512

2513
                channelID = byteOrder.Uint64(chanID)
×
2514

×
2515
                return nil
×
2516
        }, sqldb.NoOpReset)
2517
        if err != nil {
×
2518
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2519
        }
×
2520

2521
        return channelID, nil
×
2522
}
2523

2524
// IsPublicNode is a helper method that determines whether the node with the
2525
// given public key is seen as a public node in the graph from the graph's
2526
// source node's point of view.
2527
//
2528
// NOTE: part of the Store interface.
2529
func (s *SQLStore) IsPublicNode(v lnwire.GossipVersion, pubKey [33]byte) (bool,
NEW
2530
        error) {
×
NEW
2531

×
2532
        ctx := context.TODO()
×
NEW
2533
        if !isKnownGossipVersion(v) {
×
NEW
2534
                return false, fmt.Errorf("unsupported gossip version: %d", v)
×
NEW
2535
        }
×
2536

2537
        var isPublic bool
×
2538
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2539
                var err error
×
NEW
2540
                switch v {
×
NEW
2541
                case lnwire.GossipVersion1:
×
NEW
2542
                        isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
NEW
2543
                case lnwire.GossipVersion2:
×
NEW
2544
                        isPublic, err = db.IsPublicV2Node(ctx, pubKey[:])
×
2545
                }
2546

2547
                return err
×
2548
        }, sqldb.NoOpReset)
2549
        if err != nil {
×
2550
                return false, fmt.Errorf("unable to check if node is "+
×
2551
                        "public: %w", err)
×
2552
        }
×
2553

2554
        return isPublic, nil
×
2555
}
2556

2557
// FetchChanInfos returns the set of channel edges that correspond to the passed
2558
// channel ID's. If an edge is the query is unknown to the database, it will
2559
// skipped and the result will contain only those edges that exist at the time
2560
// of the query. This can be used to respond to peer queries that are seeking to
2561
// fill in gaps in their view of the channel graph.
2562
//
2563
// NOTE: part of the Store interface.
2564
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2565
        var (
×
2566
                ctx   = context.TODO()
×
2567
                edges = make(map[uint64]ChannelEdge)
×
2568
        )
×
2569
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2570
                // First, collect all channel rows.
×
2571
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2572
                chanCallBack := func(ctx context.Context,
×
2573
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2574

×
2575
                        channelRows = append(channelRows, row)
×
2576
                        return nil
×
2577
                }
×
2578

2579
                err := s.forEachChanWithPoliciesInSCIDList(
×
NEW
2580
                        ctx, db, lnwire.GossipVersion1, chanCallBack, chanIDs,
×
2581
                )
×
2582
                if err != nil {
×
2583
                        return err
×
2584
                }
×
2585

2586
                if len(channelRows) == 0 {
×
2587
                        return nil
×
2588
                }
×
2589

2590
                // Batch build all channel edges.
2591
                chans, err := batchBuildChannelEdges(
×
2592
                        ctx, s.cfg, db, channelRows,
×
2593
                )
×
2594
                if err != nil {
×
2595
                        return fmt.Errorf("unable to build channel edges: %w",
×
2596
                                err)
×
2597
                }
×
2598

2599
                for _, c := range chans {
×
2600
                        edges[c.Info.ChannelID] = c
×
2601
                }
×
2602

2603
                return err
×
2604
        }, func() {
×
2605
                clear(edges)
×
2606
        })
×
2607
        if err != nil {
×
2608
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2609
        }
×
2610

2611
        res := make([]ChannelEdge, 0, len(edges))
×
2612
        for _, chanID := range chanIDs {
×
2613
                edge, ok := edges[chanID]
×
2614
                if !ok {
×
2615
                        continue
×
2616
                }
2617

2618
                res = append(res, edge)
×
2619
        }
2620

2621
        return res, nil
×
2622
}
2623

2624
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2625
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2626
// channels in a paginated manner.
2627
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2628
        db SQLQueries, v lnwire.GossipVersion, cb func(ctx context.Context,
2629
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2630
        chanIDs []uint64) error {
×
2631

×
2632
        queryWrapper := func(ctx context.Context,
×
2633
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2634
                error) {
×
2635

×
2636
                return db.GetChannelsBySCIDWithPolicies(
×
2637
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
NEW
2638
                                Version: int16(v),
×
2639
                                Scids:   scids,
×
2640
                        },
×
2641
                )
×
2642
        }
×
2643

2644
        return sqldb.ExecuteBatchQuery(
×
2645
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2646
                cb,
×
2647
        )
×
2648
}
2649

2650
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2651
// ID's that we don't know and are not known zombies of the passed set. In other
2652
// words, we perform a set difference of our set of chan ID's and the ones
2653
// passed in. This method can be used by callers to determine the set of
2654
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2655
// known zombies is also returned.
2656
//
2657
// NOTE: part of the Store interface.
2658
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2659
        []ChannelUpdateInfo, error) {
×
2660

×
2661
        var (
×
2662
                ctx          = context.TODO()
×
2663
                newChanIDs   []uint64
×
2664
                knownZombies []ChannelUpdateInfo
×
2665
                infoLookup   = make(
×
2666
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2667
                )
×
2668
        )
×
2669

×
2670
        // We first build a lookup map of the channel ID's to the
×
2671
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2672
        // already know about.
×
2673
        for _, chanInfo := range chansInfo {
×
2674
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2675
        }
×
2676

2677
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2678
                // The call-back function deletes known channels from
×
2679
                // infoLookup, so that we can later check which channels are
×
2680
                // zombies by only looking at the remaining channels in the set.
×
2681
                cb := func(ctx context.Context,
×
2682
                        channel sqlc.GraphChannel) error {
×
2683

×
2684
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2685

×
2686
                        return nil
×
2687
                }
×
2688

2689
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2690
                if err != nil {
×
2691
                        return fmt.Errorf("unable to iterate through "+
×
2692
                                "channels: %w", err)
×
2693
                }
×
2694

2695
                // We want to ensure that we deal with the channels in the
2696
                // same order that they were passed in, so we iterate over the
2697
                // original chansInfo slice and then check if that channel is
2698
                // still in the infoLookup map.
2699
                for _, chanInfo := range chansInfo {
×
2700
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2701
                        if _, ok := infoLookup[channelID]; !ok {
×
2702
                                continue
×
2703
                        }
2704

2705
                        isZombie, err := db.IsZombieChannel(
×
2706
                                ctx, sqlc.IsZombieChannelParams{
×
2707
                                        Scid:    channelIDToBytes(channelID),
×
2708
                                        Version: int16(lnwire.GossipVersion1),
×
2709
                                },
×
2710
                        )
×
2711
                        if err != nil {
×
2712
                                return fmt.Errorf("unable to fetch zombie "+
×
2713
                                        "channel: %w", err)
×
2714
                        }
×
2715

2716
                        if isZombie {
×
2717
                                knownZombies = append(knownZombies, chanInfo)
×
2718

×
2719
                                continue
×
2720
                        }
2721

2722
                        newChanIDs = append(newChanIDs, channelID)
×
2723
                }
2724

2725
                return nil
×
2726
        }, func() {
×
2727
                newChanIDs = nil
×
2728
                knownZombies = nil
×
2729
                // Rebuild the infoLookup map in case of a rollback.
×
2730
                for _, chanInfo := range chansInfo {
×
2731
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2732
                        infoLookup[scid] = chanInfo
×
2733
                }
×
2734
        })
2735
        if err != nil {
×
2736
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2737
        }
×
2738

2739
        return newChanIDs, knownZombies, nil
×
2740
}
2741

2742
// forEachChanInSCIDList is a helper method that executes a paged query
2743
// against the database to fetch all channels that match the passed
2744
// ChannelUpdateInfo slice. The callback function is called for each channel
2745
// that is found.
2746
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2747
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2748
        chansInfo []ChannelUpdateInfo) error {
×
2749

×
2750
        queryWrapper := func(ctx context.Context,
×
2751
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2752

×
2753
                return db.GetChannelsBySCIDs(
×
2754
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2755
                                Version: int16(lnwire.GossipVersion1),
×
2756
                                Scids:   scids,
×
2757
                        },
×
2758
                )
×
2759
        }
×
2760

2761
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2762
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2763

×
2764
                return channelIDToBytes(channelID)
×
2765
        }
×
2766

2767
        return sqldb.ExecuteBatchQuery(
×
2768
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2769
                cb,
×
2770
        )
×
2771
}
2772

2773
// PruneGraphNodes is a garbage collection method which attempts to prune out
2774
// any nodes from the channel graph that are currently unconnected. This ensure
2775
// that we only maintain a graph of reachable nodes. In the event that a pruned
2776
// node gains more channels, it will be re-added back to the graph.
2777
//
2778
// NOTE: this prunes nodes across protocol versions. It will never prune the
2779
// source nodes.
2780
//
2781
// NOTE: part of the Store interface.
2782
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2783
        var ctx = context.TODO()
×
2784

×
2785
        var prunedNodes []route.Vertex
×
2786
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2787
                var err error
×
2788
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2789

×
2790
                return err
×
2791
        }, func() {
×
2792
                prunedNodes = nil
×
2793
        })
×
2794
        if err != nil {
×
2795
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2796
        }
×
2797

2798
        return prunedNodes, nil
×
2799
}
2800

2801
// PruneGraph prunes newly closed channels from the channel graph in response
2802
// to a new block being solved on the network. Any transactions which spend the
2803
// funding output of any known channels within he graph will be deleted.
2804
// Additionally, the "prune tip", or the last block which has been used to
2805
// prune the graph is stored so callers can ensure the graph is fully in sync
2806
// with the current UTXO state. A slice of channels that have been closed by
2807
// the target block along with any pruned nodes are returned if the function
2808
// succeeds without error.
2809
//
2810
// NOTE: part of the Store interface.
2811
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2812
        blockHash *chainhash.Hash, blockHeight uint32) (
2813
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2814

×
2815
        ctx := context.TODO()
×
2816

×
2817
        s.cacheMu.Lock()
×
2818
        defer s.cacheMu.Unlock()
×
2819

×
2820
        var (
×
2821
                closedChans []*models.ChannelEdgeInfo
×
2822
                prunedNodes []route.Vertex
×
2823
        )
×
2824
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2825
                // First, collect all channel rows that need to be pruned.
×
2826
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2827
                channelCallback := func(ctx context.Context,
×
2828
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2829

×
2830
                        channelRows = append(channelRows, row)
×
2831

×
2832
                        return nil
×
2833
                }
×
2834

2835
                err := s.forEachChanInOutpoints(
×
2836
                        ctx, db, spentOutputs, channelCallback,
×
2837
                )
×
2838
                if err != nil {
×
2839
                        return fmt.Errorf("unable to fetch channels by "+
×
2840
                                "outpoints: %w", err)
×
2841
                }
×
2842

2843
                if len(channelRows) == 0 {
×
2844
                        // There are no channels to prune. So we can exit early
×
2845
                        // after updating the prune log.
×
2846
                        err = db.UpsertPruneLogEntry(
×
2847
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2848
                                        BlockHash:   blockHash[:],
×
2849
                                        BlockHeight: int64(blockHeight),
×
2850
                                },
×
2851
                        )
×
2852
                        if err != nil {
×
2853
                                return fmt.Errorf("unable to insert prune log "+
×
2854
                                        "entry: %w", err)
×
2855
                        }
×
2856

2857
                        return nil
×
2858
                }
2859

2860
                // Batch build all channel edges for pruning.
2861
                var chansToDelete []int64
×
2862
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2863
                        ctx, s.cfg, db, channelRows,
×
2864
                )
×
2865
                if err != nil {
×
2866
                        return err
×
2867
                }
×
2868

2869
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2870
                if err != nil {
×
2871
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2872
                }
×
2873

2874
                err = db.UpsertPruneLogEntry(
×
2875
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2876
                                BlockHash:   blockHash[:],
×
2877
                                BlockHeight: int64(blockHeight),
×
2878
                        },
×
2879
                )
×
2880
                if err != nil {
×
2881
                        return fmt.Errorf("unable to insert prune log "+
×
2882
                                "entry: %w", err)
×
2883
                }
×
2884

2885
                // Now that we've pruned some channels, we'll also prune any
2886
                // nodes that no longer have any channels.
2887
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2888
                if err != nil {
×
2889
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2890
                                err)
×
2891
                }
×
2892

2893
                return nil
×
2894
        }, func() {
×
2895
                prunedNodes = nil
×
2896
                closedChans = nil
×
2897
        })
×
2898
        if err != nil {
×
2899
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2900
        }
×
2901

2902
        for _, channel := range closedChans {
×
NEW
2903
                s.rejectCache.remove(channel.Version, channel.ChannelID)
×
2904
                s.chanCache.remove(channel.ChannelID)
×
2905
        }
×
2906

2907
        return closedChans, prunedNodes, nil
×
2908
}
2909

2910
// forEachChanInOutpoints is a helper function that executes a paginated
2911
// query to fetch channels by their outpoints and applies the given call-back
2912
// to each.
2913
//
2914
// NOTE: this fetches channels for all protocol versions.
2915
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2916
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2917
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2918

×
2919
        // Create a wrapper that uses the transaction's db instance to execute
×
2920
        // the query.
×
2921
        queryWrapper := func(ctx context.Context,
×
2922
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2923
                error) {
×
2924

×
2925
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2926
        }
×
2927

2928
        // Define the conversion function from Outpoint to string.
2929
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2930
                return outpoint.String()
×
2931
        }
×
2932

2933
        return sqldb.ExecuteBatchQuery(
×
2934
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2935
                queryWrapper, cb,
×
2936
        )
×
2937
}
2938

2939
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2940
        dbIDs []int64) error {
×
2941

×
2942
        // Create a wrapper that uses the transaction's db instance to execute
×
2943
        // the query.
×
2944
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2945
                return nil, db.DeleteChannels(ctx, ids)
×
2946
        }
×
2947

2948
        idConverter := func(id int64) int64 {
×
2949
                return id
×
2950
        }
×
2951

2952
        return sqldb.ExecuteBatchQuery(
×
2953
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2954
                queryWrapper, func(ctx context.Context, _ any) error {
×
2955
                        return nil
×
2956
                },
×
2957
        )
2958
}
2959

2960
// ChannelView returns the verifiable edge information for each active channel
2961
// within the known channel graph. The set of UTXOs (along with their scripts)
2962
// returned are the ones that need to be watched on chain to detect channel
2963
// closes on the resident blockchain.
2964
//
2965
// NOTE: part of the Store interface.
2966
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2967
        var (
×
2968
                ctx        = context.TODO()
×
2969
                edgePoints []EdgePoint
×
2970
        )
×
2971

×
2972
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2973
                handleChannel := func(_ context.Context,
×
2974
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2975

×
NEW
2976
                        // TODO(elle): update to handle V2 channels.
×
2977
                        pkScript, err := genMultiSigP2WSH(
×
2978
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2979
                        )
×
2980
                        if err != nil {
×
2981
                                return err
×
2982
                        }
×
2983

2984
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2985
                        if err != nil {
×
2986
                                return err
×
2987
                        }
×
2988

2989
                        edgePoints = append(edgePoints, EdgePoint{
×
2990
                                FundingPkScript: pkScript,
×
2991
                                OutPoint:        *op,
×
2992
                        })
×
2993

×
2994
                        return nil
×
2995
                }
2996

2997
                queryFunc := func(ctx context.Context, lastID int64,
×
2998
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2999

×
3000
                        return db.ListChannelsPaginated(
×
3001
                                ctx, sqlc.ListChannelsPaginatedParams{
×
3002
                                        Version: int16(lnwire.GossipVersion1),
×
3003
                                        ID:      lastID,
×
3004
                                        Limit:   limit,
×
3005
                                },
×
3006
                        )
×
3007
                }
×
3008

3009
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
3010
                        return row.ID
×
3011
                }
×
3012

3013
                return sqldb.ExecutePaginatedQuery(
×
3014
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
3015
                        extractCursor, handleChannel,
×
3016
                )
×
3017
        }, func() {
×
3018
                edgePoints = nil
×
3019
        })
×
3020
        if err != nil {
×
3021
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
3022
        }
×
3023

3024
        return edgePoints, nil
×
3025
}
3026

3027
// PruneTip returns the block height and hash of the latest block that has been
3028
// used to prune channels in the graph. Knowing the "prune tip" allows callers
3029
// to tell if the graph is currently in sync with the current best known UTXO
3030
// state.
3031
//
3032
// NOTE: part of the Store interface.
3033
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
3034
        var (
×
3035
                ctx       = context.TODO()
×
3036
                tipHash   chainhash.Hash
×
3037
                tipHeight uint32
×
3038
        )
×
3039
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3040
                pruneTip, err := db.GetPruneTip(ctx)
×
3041
                if errors.Is(err, sql.ErrNoRows) {
×
3042
                        return ErrGraphNeverPruned
×
3043
                } else if err != nil {
×
3044
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
3045
                }
×
3046

3047
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
3048
                tipHeight = uint32(pruneTip.BlockHeight)
×
3049

×
3050
                return nil
×
3051
        }, sqldb.NoOpReset)
3052
        if err != nil {
×
3053
                return nil, 0, err
×
3054
        }
×
3055

3056
        return &tipHash, tipHeight, nil
×
3057
}
3058

3059
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
3060
//
3061
// NOTE: this prunes nodes across protocol versions. It will never prune the
3062
// source nodes.
3063
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
3064
        db SQLQueries) ([]route.Vertex, error) {
×
3065

×
3066
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
3067
        if err != nil {
×
3068
                return nil, fmt.Errorf("unable to delete unconnected "+
×
3069
                        "nodes: %w", err)
×
3070
        }
×
3071

3072
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
3073
        for i, nodeKey := range nodeKeys {
×
3074
                pub, err := route.NewVertexFromBytes(nodeKey)
×
3075
                if err != nil {
×
3076
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
3077
                                "from bytes: %w", err)
×
3078
                }
×
3079

3080
                prunedNodes[i] = pub
×
3081
        }
3082

3083
        return prunedNodes, nil
×
3084
}
3085

3086
// DisconnectBlockAtHeight is used to indicate that the block specified
3087
// by the passed height has been disconnected from the main chain. This
3088
// will "rewind" the graph back to the height below, deleting channels
3089
// that are no longer confirmed from the graph. The prune log will be
3090
// set to the last prune height valid for the remaining chain.
3091
// Channels that were removed from the graph resulting from the
3092
// disconnected block are returned.
3093
//
3094
// NOTE: part of the Store interface.
3095
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
3096
        []*models.ChannelEdgeInfo, error) {
×
3097

×
3098
        ctx := context.TODO()
×
3099

×
3100
        var (
×
3101
                // Every channel having a ShortChannelID starting at 'height'
×
3102
                // will no longer be confirmed.
×
3103
                startShortChanID = lnwire.ShortChannelID{
×
3104
                        BlockHeight: height,
×
3105
                }
×
3106

×
3107
                // Delete everything after this height from the db up until the
×
3108
                // SCID alias range.
×
3109
                endShortChanID = aliasmgr.StartingAlias
×
3110

×
3111
                removedChans []*models.ChannelEdgeInfo
×
3112

×
3113
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
3114
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
3115
        )
×
3116

×
3117
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3118
                rows, err := db.GetChannelsBySCIDRange(
×
3119
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
3120
                                StartScid: chanIDStart,
×
3121
                                EndScid:   chanIDEnd,
×
3122
                        },
×
3123
                )
×
3124
                if err != nil {
×
3125
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
3126
                }
×
3127

3128
                if len(rows) == 0 {
×
3129
                        // No channels to disconnect, but still clean up prune
×
3130
                        // log.
×
3131
                        return db.DeletePruneLogEntriesInRange(
×
3132
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
3133
                                        StartHeight: int64(height),
×
3134
                                        EndHeight: int64(
×
3135
                                                endShortChanID.BlockHeight,
×
3136
                                        ),
×
3137
                                },
×
3138
                        )
×
3139
                }
×
3140

3141
                // Batch build all channel edges for disconnection.
3142
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
3143
                        ctx, s.cfg, db, rows,
×
3144
                )
×
3145
                if err != nil {
×
3146
                        return err
×
3147
                }
×
3148

3149
                removedChans = channelEdges
×
3150

×
3151
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
3152
                if err != nil {
×
3153
                        return fmt.Errorf("unable to delete channels: %w", err)
×
3154
                }
×
3155

3156
                return db.DeletePruneLogEntriesInRange(
×
3157
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
3158
                                StartHeight: int64(height),
×
3159
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
3160
                        },
×
3161
                )
×
3162
        }, func() {
×
3163
                removedChans = nil
×
3164
        })
×
3165
        if err != nil {
×
3166
                return nil, fmt.Errorf("unable to disconnect block at "+
×
3167
                        "height: %w", err)
×
3168
        }
×
3169

3170
        s.cacheMu.Lock()
×
3171
        for _, channel := range removedChans {
×
NEW
3172
                s.rejectCache.remove(channel.Version, channel.ChannelID)
×
3173
                s.chanCache.remove(channel.ChannelID)
×
3174
        }
×
3175
        s.cacheMu.Unlock()
×
3176

×
3177
        return removedChans, nil
×
3178
}
3179

3180
// AddEdgeProof sets the proof of an existing edge in the graph database.
3181
//
3182
// NOTE: part of the Store interface.
3183
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
3184
        proof *models.ChannelAuthProof) error {
×
3185

×
NEW
3186
        if !isKnownGossipVersion(proof.Version) {
×
NEW
3187
                return fmt.Errorf("unsupported gossip version: %d",
×
NEW
3188
                        proof.Version)
×
NEW
3189
        }
×
3190

3191
        var (
×
3192
                ctx       = context.TODO()
×
3193
                scidBytes = channelIDToBytes(scid.ToUint64())
×
3194
        )
×
3195

×
3196
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
NEW
3197
                var (
×
NEW
3198
                        res sql.Result
×
NEW
3199
                        err error
×
3200
                )
×
NEW
3201
                switch proof.Version {
×
NEW
3202
                case lnwire.GossipVersion1:
×
NEW
3203
                        res, err = db.AddV1ChannelProof(
×
NEW
3204
                                ctx, sqlc.AddV1ChannelProofParams{
×
NEW
3205
                                        Scid:              scidBytes,
×
NEW
3206
                                        Node1Signature:    proof.NodeSig1(),
×
NEW
3207
                                        Node2Signature:    proof.NodeSig2(),
×
NEW
3208
                                        Bitcoin1Signature: proof.BitcoinSig1(),
×
NEW
3209
                                        Bitcoin2Signature: proof.BitcoinSig2(),
×
NEW
3210
                                },
×
NEW
3211
                        )
×
3212

NEW
3213
                case lnwire.GossipVersion2:
×
NEW
3214
                        res, err = db.AddV2ChannelProof(
×
NEW
3215
                                ctx, sqlc.AddV2ChannelProofParams{
×
NEW
3216
                                        Scid:      scidBytes,
×
NEW
3217
                                        Signature: proof.Sig(),
×
NEW
3218
                                },
×
NEW
3219
                        )
×
3220

NEW
3221
                default:
×
NEW
3222
                        return fmt.Errorf("unsupported gossip version: %d",
×
NEW
3223
                                proof.Version)
×
3224
                }
3225
                if err != nil {
×
3226
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
3227
                }
×
3228

3229
                n, err := res.RowsAffected()
×
3230
                if err != nil {
×
3231
                        return err
×
3232
                }
×
3233

3234
                if n == 0 {
×
3235
                        return fmt.Errorf("no rows affected when adding edge "+
×
3236
                                "proof for SCID %v", scid)
×
3237
                } else if n > 1 {
×
3238
                        return fmt.Errorf("multiple rows affected when adding "+
×
3239
                                "edge proof for SCID %v: %d rows affected",
×
3240
                                scid, n)
×
3241
                }
×
3242

3243
                return nil
×
3244
        }, sqldb.NoOpReset)
3245
        if err != nil {
×
3246
                return fmt.Errorf("unable to add edge proof: %w", err)
×
3247
        }
×
3248

3249
        return nil
×
3250
}
3251

3252
// PutClosedScid stores a SCID for a closed channel in the database. This is so
3253
// that we can ignore channel announcements that we know to be closed without
3254
// having to validate them and fetch a block.
3255
//
3256
// NOTE: part of the Store interface.
3257
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
3258
        var (
×
3259
                ctx     = context.TODO()
×
3260
                chanIDB = channelIDToBytes(scid.ToUint64())
×
3261
        )
×
3262

×
3263
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3264
                return db.InsertClosedChannel(ctx, chanIDB)
×
3265
        }, sqldb.NoOpReset)
×
3266
}
3267

3268
// IsClosedScid checks whether a channel identified by the passed in scid is
3269
// closed. This helps avoid having to perform expensive validation checks.
3270
//
3271
// NOTE: part of the Store interface.
3272
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
3273
        var (
×
3274
                ctx      = context.TODO()
×
3275
                isClosed bool
×
3276
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
3277
        )
×
3278
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3279
                var err error
×
3280
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
3281
                if err != nil {
×
3282
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
3283
                                err)
×
3284
                }
×
3285

3286
                return nil
×
3287
        }, sqldb.NoOpReset)
3288
        if err != nil {
×
3289
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3290
                        err)
×
3291
        }
×
3292

3293
        return isClosed, nil
×
3294
}
3295

3296
// GraphSession will provide the call-back with access to a NodeTraverser
3297
// instance which can be used to perform queries against the channel graph.
3298
//
3299
// NOTE: part of the Store interface.
3300
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3301
        reset func()) error {
×
3302

×
3303
        var ctx = context.TODO()
×
3304

×
3305
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3306
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3307
        }, reset)
×
3308
}
3309

3310
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3311
// read only transaction for a consistent view of the graph.
3312
type sqlNodeTraverser struct {
3313
        db    SQLQueries
3314
        chain chainhash.Hash
3315
}
3316

3317
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3318
// NodeTraverser interface.
3319
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3320

3321
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3322
func newSQLNodeTraverser(db SQLQueries,
3323
        chain chainhash.Hash) *sqlNodeTraverser {
×
3324

×
3325
        return &sqlNodeTraverser{
×
3326
                db:    db,
×
3327
                chain: chain,
×
3328
        }
×
3329
}
×
3330

3331
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3332
// node.
3333
//
3334
// NOTE: Part of the NodeTraverser interface.
3335
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3336
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3337

×
3338
        ctx := context.TODO()
×
3339

×
3340
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3341
}
×
3342

3343
// FetchNodeFeatures returns the features of the given node. If the node is
3344
// unknown, assume no additional features are supported.
3345
//
3346
// NOTE: Part of the NodeTraverser interface.
3347
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3348
        *lnwire.FeatureVector, error) {
×
3349

×
3350
        ctx := context.TODO()
×
3351

×
NEW
3352
        return fetchNodeFeatures(ctx, s.db, lnwire.GossipVersion1, nodePub)
×
3353
}
×
3354

3355
// forEachNodeDirectedChannel iterates through all channels of a given
3356
// node, executing the passed callback on the directed edge representing the
3357
// channel and its incoming policy. If the node is not found, no error is
3358
// returned.
3359
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3360
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3361

×
3362
        toNodeCallback := func() route.Vertex {
×
3363
                return nodePub
×
3364
        }
×
3365

3366
        dbID, err := db.GetNodeIDByPubKey(
×
3367
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3368
                        Version: int16(lnwire.GossipVersion1),
×
3369
                        PubKey:  nodePub[:],
×
3370
                },
×
3371
        )
×
3372
        if errors.Is(err, sql.ErrNoRows) {
×
3373
                return nil
×
3374
        } else if err != nil {
×
3375
                return fmt.Errorf("unable to fetch node: %w", err)
×
3376
        }
×
3377

3378
        rows, err := db.ListChannelsByNodeID(
×
3379
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3380
                        Version: int16(lnwire.GossipVersion1),
×
3381
                        NodeID1: dbID,
×
3382
                },
×
3383
        )
×
3384
        if err != nil {
×
3385
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3386
        }
×
3387

3388
        // Exit early if there are no channels for this node so we don't
3389
        // do the unnecessary feature fetching.
3390
        if len(rows) == 0 {
×
3391
                return nil
×
3392
        }
×
3393

3394
        features, err := getNodeFeatures(ctx, db, dbID)
×
3395
        if err != nil {
×
3396
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3397
        }
×
3398

3399
        for _, row := range rows {
×
3400
                node1, node2, err := buildNodeVertices(
×
3401
                        row.Node1Pubkey, row.Node2Pubkey,
×
3402
                )
×
3403
                if err != nil {
×
3404
                        return fmt.Errorf("unable to build node vertices: %w",
×
3405
                                err)
×
3406
                }
×
3407

3408
                edge := buildCacheableChannelInfo(
×
3409
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3410
                        node1, node2,
×
3411
                )
×
3412

×
3413
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3414
                if err != nil {
×
3415
                        return err
×
3416
                }
×
3417

3418
                p1, p2, err := buildCachedChanPolicies(
×
3419
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3420
                )
×
3421
                if err != nil {
×
3422
                        return err
×
3423
                }
×
3424

3425
                // Determine the outgoing and incoming policy for this
3426
                // channel and node combo.
3427
                outPolicy, inPolicy := p1, p2
×
3428
                if p1 != nil && node2 == nodePub {
×
3429
                        outPolicy, inPolicy = p2, p1
×
3430
                } else if p2 != nil && node1 != nodePub {
×
3431
                        outPolicy, inPolicy = p2, p1
×
3432
                }
×
3433

3434
                var cachedInPolicy *models.CachedEdgePolicy
×
3435
                if inPolicy != nil {
×
3436
                        cachedInPolicy = inPolicy
×
3437
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3438
                        cachedInPolicy.ToNodeFeatures = features
×
3439
                }
×
3440

3441
                directedChannel := &DirectedChannel{
×
3442
                        ChannelID:    edge.ChannelID,
×
3443
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3444
                        OtherNode:    edge.NodeKey2Bytes,
×
3445
                        Capacity:     edge.Capacity,
×
3446
                        OutPolicySet: outPolicy != nil,
×
3447
                        InPolicy:     cachedInPolicy,
×
3448
                }
×
3449
                if outPolicy != nil {
×
3450
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3451
                                directedChannel.InboundFee = fee
×
3452
                        })
×
3453
                }
3454

3455
                if nodePub == edge.NodeKey2Bytes {
×
3456
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3457
                }
×
3458

3459
                if err := cb(directedChannel); err != nil {
×
3460
                        return err
×
3461
                }
×
3462
        }
3463

3464
        return nil
×
3465
}
3466

3467
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3468
// and executes the provided callback for each node. It does so via pagination
3469
// along with batch loading of the node feature bits.
3470
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3471
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3472
                features *lnwire.FeatureVector) error) error {
×
3473

×
3474
        handleNode := func(_ context.Context,
×
3475
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3476
                featureBits map[int64][]int) error {
×
3477

×
3478
                fv := lnwire.EmptyFeatureVector()
×
3479
                if features, exists := featureBits[dbNode.ID]; exists {
×
3480
                        for _, bit := range features {
×
3481
                                fv.Set(lnwire.FeatureBit(bit))
×
3482
                        }
×
3483
                }
3484

3485
                var pub route.Vertex
×
3486
                copy(pub[:], dbNode.PubKey)
×
3487

×
3488
                return processNode(dbNode.ID, pub, fv)
×
3489
        }
3490

3491
        queryFunc := func(ctx context.Context, lastID int64,
×
3492
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3493

×
3494
                return db.ListNodeIDsAndPubKeys(
×
3495
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3496
                                Version: int16(lnwire.GossipVersion1),
×
3497
                                ID:      lastID,
×
3498
                                Limit:   limit,
×
3499
                        },
×
3500
                )
×
3501
        }
×
3502

3503
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3504
                return row.ID
×
3505
        }
×
3506

3507
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3508
                return node.ID, nil
×
3509
        }
×
3510

3511
        batchQueryFunc := func(ctx context.Context,
×
3512
                nodeIDs []int64) (map[int64][]int, error) {
×
3513

×
3514
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3515
        }
×
3516

3517
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3518
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3519
                batchQueryFunc, handleNode,
×
3520
        )
×
3521
}
3522

3523
// forEachNodeChannel iterates through all channels of a node, executing
3524
// the passed callback on each. The call-back is provided with the channel's
3525
// edge information, the outgoing policy and the incoming policy for the
3526
// channel and node combo.
3527
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3528
        cfg *SQLStoreConfig, v lnwire.GossipVersion, id int64,
3529
        cb func(*models.ChannelEdgeInfo,
3530
                *models.ChannelEdgePolicy,
3531
                *models.ChannelEdgePolicy) error) error {
×
3532

×
NEW
3533
        // Get all the channels for this node.
×
3534
        rows, err := db.ListChannelsByNodeID(
×
3535
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
3536
                        Version: int16(v),
×
3537
                        NodeID1: id,
×
3538
                },
×
3539
        )
×
3540
        if err != nil {
×
3541
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3542
        }
×
3543

3544
        // Collect all the channel and policy IDs.
3545
        var (
×
3546
                chanIDs   = make([]int64, 0, len(rows))
×
3547
                policyIDs = make([]int64, 0, 2*len(rows))
×
3548
        )
×
3549
        for _, row := range rows {
×
3550
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3551

×
3552
                if row.Policy1ID.Valid {
×
3553
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3554
                }
×
3555
                if row.Policy2ID.Valid {
×
3556
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3557
                }
×
3558
        }
3559

3560
        batchData, err := batchLoadChannelData(
×
3561
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3562
        )
×
3563
        if err != nil {
×
3564
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3565
        }
×
3566

3567
        // Call the call-back for each channel and its known policies.
3568
        for _, row := range rows {
×
3569
                node1, node2, err := buildNodeVertices(
×
3570
                        row.Node1Pubkey, row.Node2Pubkey,
×
3571
                )
×
3572
                if err != nil {
×
3573
                        return fmt.Errorf("unable to build node vertices: %w",
×
3574
                                err)
×
3575
                }
×
3576

3577
                edge, err := buildEdgeInfoWithBatchData(
×
3578
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3579
                        batchData,
×
3580
                )
×
3581
                if err != nil {
×
3582
                        return fmt.Errorf("unable to build channel info: %w",
×
3583
                                err)
×
3584
                }
×
3585

3586
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3587
                if err != nil {
×
3588
                        return fmt.Errorf("unable to extract channel "+
×
3589
                                "policies: %w", err)
×
3590
                }
×
3591

3592
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3593
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3594
                )
×
3595
                if err != nil {
×
3596
                        return fmt.Errorf("unable to build channel "+
×
3597
                                "policies: %w", err)
×
3598
                }
×
3599

3600
                // Determine the outgoing and incoming policy for this
3601
                // channel and node combo.
3602
                p1ToNode := row.GraphChannel.NodeID2
×
3603
                p2ToNode := row.GraphChannel.NodeID1
×
3604
                outPolicy, inPolicy := p1, p2
×
3605
                if (p1 != nil && p1ToNode == id) ||
×
3606
                        (p2 != nil && p2ToNode != id) {
×
3607

×
3608
                        outPolicy, inPolicy = p2, p1
×
3609
                }
×
3610

3611
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3612
                        return err
×
3613
                }
×
3614
        }
3615

3616
        return nil
×
3617
}
3618

3619
// updateChanEdgePolicy upserts the channel policy info we have stored for
3620
// a channel we already know of.
3621
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3622
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3623
        error) {
×
3624

×
3625
        var (
×
3626
                node1Pub, node2Pub route.Vertex
×
3627
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
NEW
3628
                version            = edge.Version
×
3629
        )
×
3630

×
NEW
3631
        if !isKnownGossipVersion(version) {
×
NEW
3632
                return node1Pub, node2Pub, false, fmt.Errorf(
×
NEW
3633
                        "unsupported gossip version: %d", version,
×
NEW
3634
                )
×
NEW
3635
        }
×
3636

3637
        // Check that this edge policy refers to a channel that we already
3638
        // know of. We do this explicitly so that we can return the appropriate
3639
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
3640
        // abort the transaction which would abort the entire batch.
3641
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3642
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3643
                        Scid:    chanIDB,
×
NEW
3644
                        Version: int16(version),
×
3645
                },
×
3646
        )
×
3647
        if errors.Is(err, sql.ErrNoRows) {
×
3648
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3649
        } else if err != nil {
×
3650
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3651
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3652
        }
×
3653

3654
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3655
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3656

×
3657
        // Figure out which node this edge is from.
×
NEW
3658
        isNode1 := edge.IsNode1()
×
3659
        nodeID := dbChan.NodeID1
×
3660
        if !isNode1 {
×
3661
                nodeID = dbChan.NodeID2
×
3662
        }
×
3663

3664
        var (
×
3665
                inboundBase sql.NullInt64
×
3666
                inboundRate sql.NullInt64
×
3667
        )
×
3668
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3669
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3670
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3671
        })
×
3672

NEW
3673
        params := sqlc.UpsertEdgePolicyParams{
×
NEW
3674
                Version:                 int16(version),
×
NEW
3675
                ChannelID:               dbChan.ID,
×
NEW
3676
                NodeID:                  nodeID,
×
NEW
3677
                Timelock:                int32(edge.TimeLockDelta),
×
NEW
3678
                FeePpm:                  int64(edge.FeeProportionalMillionths),
×
NEW
3679
                BaseFeeMsat:             int64(edge.FeeBaseMSat),
×
NEW
3680
                MinHtlcMsat:             int64(edge.MinHTLC),
×
3681
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3682
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3683
                InboundBaseFeeMsat:      inboundBase,
×
3684
                InboundFeeRateMilliMsat: inboundRate,
×
3685
                Signature:               edge.SigBytes,
×
NEW
3686
        }
×
NEW
3687

×
NEW
3688
        if version == lnwire.GossipVersion1 {
×
NEW
3689
                params.LastUpdate = sqldb.SQLInt64(edge.LastUpdate.Unix())
×
NEW
3690
                params.Disabled = sql.NullBool{
×
NEW
3691
                        Valid: true,
×
NEW
3692
                        Bool:  edge.IsDisabled(),
×
NEW
3693
                }
×
NEW
3694
                params.MaxHtlcMsat = sql.NullInt64{
×
NEW
3695
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
NEW
3696
                        Int64: int64(edge.MaxHTLC),
×
NEW
3697
                }
×
NEW
3698
        } else {
×
NEW
3699
                params.BlockHeight = sqldb.SQLInt64(
×
NEW
3700
                        int64(edge.LastBlockHeight),
×
NEW
3701
                )
×
NEW
3702
                params.DisableFlags = sqldb.SQLInt16(edge.DisableFlags)
×
NEW
3703
                params.MaxHtlcMsat = sqldb.SQLInt64(int64(edge.MaxHTLC))
×
NEW
3704
        }
×
3705

NEW
3706
        id, err := tx.UpsertEdgePolicy(ctx, params)
×
3707
        if err != nil {
×
3708
                return node1Pub, node2Pub, isNode1,
×
3709
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3710
        }
×
3711

3712
        // Convert the flat extra opaque data into a map of TLV types to
3713
        // values.
NEW
3714
        extra := edge.ExtraSignedFields
×
NEW
3715
        if version == lnwire.GossipVersion1 {
×
NEW
3716
                extra, err = marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
NEW
3717
                if err != nil {
×
NEW
3718
                        return node1Pub, node2Pub, false, fmt.Errorf(
×
NEW
3719
                                "unable to marshal extra opaque data: %w", err,
×
NEW
3720
                        )
×
NEW
3721
                }
×
3722
        }
3723

3724
        // Update the channel policy's extra signed fields.
3725
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3726
        if err != nil {
×
3727
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3728
                        "policy extra TLVs: %w", err)
×
3729
        }
×
3730

3731
        return node1Pub, node2Pub, isNode1, nil
×
3732
}
3733

3734
// getNodeByPubKey attempts to look up a target node by its public key.
3735
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3736
        v lnwire.GossipVersion, pubKey route.Vertex) (int64, *models.Node,
NEW
3737
        error) {
×
3738

×
3739
        dbNode, err := db.GetNodeByPubKey(
×
3740
                ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
3741
                        Version: int16(v),
×
3742
                        PubKey:  pubKey[:],
×
3743
                },
×
3744
        )
×
3745
        if errors.Is(err, sql.ErrNoRows) {
×
3746
                return 0, nil, ErrGraphNodeNotFound
×
3747
        } else if err != nil {
×
3748
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3749
        }
×
3750

3751
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3752
        if err != nil {
×
3753
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3754
        }
×
3755

3756
        return dbNode.ID, node, nil
×
3757
}
3758

3759
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3760
// provided parameters.
3761
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3762
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3763

×
3764
        return &models.CachedEdgeInfo{
×
3765
                ChannelID:     byteOrder.Uint64(scid),
×
3766
                NodeKey1Bytes: node1Pub,
×
3767
                NodeKey2Bytes: node2Pub,
×
3768
                Capacity:      btcutil.Amount(capacity),
×
3769
        }
×
3770
}
×
3771

3772
// buildNode constructs a Node instance from the given database node
3773
// record. The node's features, addresses and extra signed fields are also
3774
// fetched from the database and set on the node.
3775
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3776
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3777

×
3778
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3779
        if err != nil {
×
3780
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3781
                        err)
×
3782
        }
×
3783

3784
        return buildNodeWithBatchData(dbNode, data)
×
3785
}
3786

3787
// isKnownGossipVersion checks whether the provided gossip version is known
3788
// and supported.
NEW
3789
func isKnownGossipVersion(v lnwire.GossipVersion) bool {
×
NEW
3790
        switch v {
×
NEW
3791
        case lnwire.GossipVersion1:
×
NEW
3792
                return true
×
NEW
3793
        case lnwire.GossipVersion2:
×
NEW
3794
                return true
×
NEW
3795
        default:
×
NEW
3796
                return false
×
3797
        }
3798
}
3799

3800
// buildNodeWithBatchData builds a models.Node instance
3801
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3802
// features/addresses/extra fields, then the corresponding fields are expected
3803
// to be present in the batchNodeData.
3804
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3805
        batchData *batchNodeData) (*models.Node, error) {
×
3806

×
NEW
3807
        v := lnwire.GossipVersion(dbNode.Version)
×
NEW
3808

×
NEW
3809
        if !isKnownGossipVersion(v) {
×
NEW
3810
                return nil, fmt.Errorf("unknown node version: %d", v)
×
UNCOV
3811
        }
×
3812

NEW
3813
        pub, err := route.NewVertexFromBytes(dbNode.PubKey)
×
NEW
3814
        if err != nil {
×
NEW
3815
                return nil, fmt.Errorf("unable to parse pubkey: %w", err)
×
NEW
3816
        }
×
3817

NEW
3818
        node := models.NewShellNode(v, pub)
×
3819

×
3820
        if len(dbNode.Signature) == 0 {
×
3821
                return node, nil
×
3822
        }
×
3823

3824
        node.AuthSigBytes = dbNode.Signature
×
3825

×
3826
        if dbNode.Alias.Valid {
×
3827
                node.Alias = fn.Some(dbNode.Alias.String)
×
3828
        }
×
3829
        if dbNode.LastUpdate.Valid {
×
3830
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3831
        }
×
NEW
3832
        if dbNode.BlockHeight.Valid {
×
NEW
3833
                node.LastBlockHeight = uint32(dbNode.BlockHeight.Int64)
×
NEW
3834
        }
×
3835

3836
        if dbNode.Color.Valid {
×
3837
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3838
                if err != nil {
×
3839
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3840
                                err)
×
3841
                }
×
3842

3843
                node.Color = fn.Some(nodeColor)
×
3844
        }
3845

3846
        // Use preloaded features.
3847
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3848
                fv := lnwire.EmptyFeatureVector()
×
3849
                for _, bit := range features {
×
3850
                        fv.Set(lnwire.FeatureBit(bit))
×
3851
                }
×
3852
                node.Features = fv
×
3853
        }
3854

3855
        // Use preloaded addresses.
3856
        addresses, exists := batchData.addresses[dbNode.ID]
×
3857
        if exists && len(addresses) > 0 {
×
3858
                node.Addresses, err = buildNodeAddresses(addresses)
×
3859
                if err != nil {
×
3860
                        return nil, fmt.Errorf("unable to build addresses "+
×
3861
                                "for node(%d): %w", dbNode.ID, err)
×
3862
                }
×
3863
        }
3864

3865
        // Use preloaded extra fields.
3866
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
NEW
3867
                if v == lnwire.GossipVersion1 {
×
NEW
3868
                        records := lnwire.CustomRecords(extraFields)
×
NEW
3869
                        recs, err := records.Serialize()
×
NEW
3870
                        if err != nil {
×
NEW
3871
                                return nil, fmt.Errorf("unable to serialize "+
×
NEW
3872
                                        "extra signed fields: %w", err)
×
NEW
3873
                        }
×
3874

NEW
3875
                        if len(recs) != 0 {
×
NEW
3876
                                node.ExtraOpaqueData = recs
×
NEW
3877
                        }
×
NEW
3878
                } else if len(extraFields) > 0 {
×
NEW
3879
                        node.ExtraSignedFields = extraFields
×
UNCOV
3880
                }
×
3881
        }
3882

3883
        return node, nil
×
3884
}
3885

3886
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3887
// with the preloaded data, and executes the provided callback for each node.
3888
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3889
        db SQLQueries, nodes []sqlc.GraphNode,
3890
        cb func(dbID int64, node *models.Node) error) error {
×
3891

×
3892
        // Extract node IDs for batch loading.
×
3893
        nodeIDs := make([]int64, len(nodes))
×
3894
        for i, node := range nodes {
×
3895
                nodeIDs[i] = node.ID
×
3896
        }
×
3897

3898
        // Batch load all related data for this page.
3899
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3900
        if err != nil {
×
3901
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3902
        }
×
3903

3904
        for _, dbNode := range nodes {
×
3905
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3906
                if err != nil {
×
3907
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3908
                                dbNode.ID, err)
×
3909
                }
×
3910

3911
                if err := cb(dbNode.ID, node); err != nil {
×
3912
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3913
                                dbNode.ID, err)
×
3914
                }
×
3915
        }
3916

3917
        return nil
×
3918
}
3919

3920
// getNodeFeatures fetches the feature bits and constructs the feature vector
3921
// for a node with the given DB ID.
3922
func getNodeFeatures(ctx context.Context, db SQLQueries,
3923
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3924

×
3925
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3926
        if err != nil {
×
3927
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3928
                        nodeID, err)
×
3929
        }
×
3930

3931
        features := lnwire.EmptyFeatureVector()
×
3932
        for _, feature := range rows {
×
3933
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3934
        }
×
3935

3936
        return features, nil
×
3937
}
3938

3939
// upsertNodeAncillaryData updates the node's features, addresses, and extra
3940
// signed fields. This is common logic shared by upsertNode and
3941
// upsertSourceNode.
3942
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
3943
        nodeID int64, node *models.Node) error {
×
3944

×
3945
        // Update the node's features.
×
3946
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3947
        if err != nil {
×
3948
                return fmt.Errorf("inserting node features: %w", err)
×
3949
        }
×
3950

3951
        // Update the node's addresses.
3952
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3953
        if err != nil {
×
3954
                return fmt.Errorf("inserting node addresses: %w", err)
×
3955
        }
×
3956

3957
        // Convert the flat extra opaque data into a map of TLV types to
3958
        // values.
NEW
3959
        extra := node.ExtraSignedFields
×
NEW
3960
        if node.Version == lnwire.GossipVersion1 {
×
NEW
3961
                extra, err = marshalExtraOpaqueData(node.ExtraOpaqueData)
×
NEW
3962
                if err != nil {
×
NEW
3963
                        return fmt.Errorf("unable to marshal extra opaque "+
×
NEW
3964
                                "data: %w", err)
×
NEW
3965
                }
×
3966
        }
3967

3968
        // Update the node's extra signed fields.
3969
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3970
        if err != nil {
×
3971
                return fmt.Errorf("inserting node extra TLVs: %w", err)
×
3972
        }
×
3973

3974
        return nil
×
3975
}
3976

3977
// populateNodeParams populates the common node parameters from a models.Node.
3978
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
3979
func populateNodeParams(node *models.Node,
3980
        setParams func(lastUpdate, lastBlockHeight sql.NullInt64, alias,
3981
                colorStr sql.NullString, signature []byte)) error {
×
3982

×
3983
        if !node.HaveAnnouncement() {
×
3984
                return nil
×
3985
        }
×
3986

NEW
3987
        var (
×
NEW
3988
                alias, colorStr             sql.NullString
×
NEW
3989
                lastUpdate, lastBlockHeight sql.NullInt64
×
NEW
3990
        )
×
NEW
3991
        node.Color.WhenSome(func(rgba color.RGBA) {
×
NEW
3992
                colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
NEW
3993
        })
×
NEW
3994
        node.Alias.WhenSome(func(s string) {
×
NEW
3995
                alias = sqldb.SQLStrValid(s)
×
NEW
3996
        })
×
3997

3998
        switch node.Version {
×
3999
        case lnwire.GossipVersion1:
×
NEW
4000
                lastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
×
4001

4002
        case lnwire.GossipVersion2:
×
NEW
4003
                lastBlockHeight = sqldb.SQLInt64(int64(node.LastBlockHeight))
×
4004

4005
        default:
×
4006
                return fmt.Errorf("unknown gossip version: %d", node.Version)
×
4007
        }
4008

NEW
4009
        setParams(
×
NEW
4010
                lastUpdate, lastBlockHeight, alias, colorStr, node.AuthSigBytes,
×
NEW
4011
        )
×
NEW
4012

×
UNCOV
4013
        return nil
×
4014
}
4015

4016
// buildNodeUpsertParams builds the parameters for upserting a node using the
4017
// strict UpsertNode query (requires timestamp to be increasing).
4018
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
×
4019
        params := sqlc.UpsertNodeParams{
×
NEW
4020
                Version: int16(node.Version),
×
4021
                PubKey:  node.PubKeyBytes[:],
×
4022
        }
×
4023

×
4024
        err := populateNodeParams(
×
NEW
4025
                node, func(lastUpdate, lastBlockHeight sql.NullInt64, alias,
×
4026
                        colorStr sql.NullString,
×
4027
                        signature []byte) {
×
4028

×
4029
                        params.LastUpdate = lastUpdate
×
NEW
4030
                        params.BlockHeight = lastBlockHeight
×
4031
                        params.Alias = alias
×
4032
                        params.Color = colorStr
×
4033
                        params.Signature = signature
×
NEW
4034
                },
×
4035
        )
4036

4037
        return params, err
×
4038
}
4039

4040
// buildSourceNodeUpsertParams builds the parameters for upserting the source
4041
// node using the lenient UpsertSourceNode query (allows same timestamp).
4042
func buildSourceNodeUpsertParams(node *models.Node) (
4043
        sqlc.UpsertSourceNodeParams, error) {
×
4044

×
4045
        params := sqlc.UpsertSourceNodeParams{
×
NEW
4046
                Version: int16(node.Version),
×
4047
                PubKey:  node.PubKeyBytes[:],
×
4048
        }
×
4049

×
4050
        err := populateNodeParams(
×
NEW
4051
                node, func(lastUpdate, lastBlock sql.NullInt64, alias,
×
4052
                        colorStr sql.NullString, signature []byte) {
×
4053

×
NEW
4054
                        params.BlockHeight = lastBlock
×
4055
                        params.LastUpdate = lastUpdate
×
4056
                        params.Alias = alias
×
4057
                        params.Color = colorStr
×
4058
                        params.Signature = signature
×
4059
                },
×
4060
        )
4061

4062
        return params, err
×
4063
}
4064

4065
// upsertSourceNode upserts the source node record into the database using a
4066
// less strict upsert that allows updates even when the timestamp hasn't
4067
// changed. This is necessary to handle concurrent updates to our own node
4068
// during startup and runtime. The node's features, addresses and extra TLV
4069
// types are also updated. The node's DB ID is returned.
4070
func upsertSourceNode(ctx context.Context, db SQLQueries,
4071
        node *models.Node) (int64, error) {
×
4072

×
4073
        params, err := buildSourceNodeUpsertParams(node)
×
4074
        if err != nil {
×
4075
                return 0, err
×
4076
        }
×
4077

4078
        nodeID, err := db.UpsertSourceNode(ctx, params)
×
4079
        if err != nil {
×
4080
                return 0, fmt.Errorf("upserting source node(%x): %w",
×
4081
                        node.PubKeyBytes, err)
×
4082
        }
×
4083

4084
        // We can exit here if we don't have the announcement yet.
4085
        if !node.HaveAnnouncement() {
×
4086
                return nodeID, nil
×
4087
        }
×
4088

4089
        // Update the ancillary node data (features, addresses, extra fields).
4090
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
4091
        if err != nil {
×
4092
                return 0, err
×
4093
        }
×
4094

4095
        return nodeID, nil
×
4096
}
4097

4098
// upsertNode upserts the node record into the database. If the node already
4099
// exists, then the node's information is updated. If the node doesn't exist,
4100
// then a new node is created. The node's features, addresses and extra TLV
4101
// types are also updated. The node's DB ID is returned.
4102
func upsertNode(ctx context.Context, db SQLQueries,
4103
        node *models.Node) (int64, error) {
×
4104

×
NEW
4105
        if !isKnownGossipVersion(node.Version) {
×
NEW
4106
                return 0, fmt.Errorf("unknown gossip version: %d", node.Version)
×
NEW
4107
        }
×
4108

4109
        params, err := buildNodeUpsertParams(node)
×
4110
        if err != nil {
×
4111
                return 0, err
×
4112
        }
×
4113

4114
        nodeID, err := db.UpsertNode(ctx, params)
×
4115
        if err != nil {
×
4116
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
4117
                        err)
×
4118
        }
×
4119

4120
        // We can exit here if we don't have the announcement yet.
4121
        if !node.HaveAnnouncement() {
×
4122
                return nodeID, nil
×
4123
        }
×
4124

4125
        // Update the ancillary node data (features, addresses, extra fields).
4126
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
4127
        if err != nil {
×
4128
                return 0, err
×
4129
        }
×
4130

4131
        return nodeID, nil
×
4132
}
4133

4134
// upsertNodeFeatures updates the node's features node_features table. This
4135
// includes deleting any feature bits no longer present and inserting any new
4136
// feature bits. If the feature bit does not yet exist in the features table,
4137
// then an entry is created in that table first.
4138
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
4139
        features *lnwire.FeatureVector) error {
×
4140

×
4141
        // Get any existing features for the node.
×
4142
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
4143
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
4144
                return err
×
4145
        }
×
4146

4147
        // Copy the nodes latest set of feature bits.
4148
        newFeatures := make(map[int32]struct{})
×
4149
        if features != nil {
×
4150
                for feature := range features.Features() {
×
4151
                        newFeatures[int32(feature)] = struct{}{}
×
4152
                }
×
4153
        }
4154

4155
        // For any current feature that already exists in the DB, remove it from
4156
        // the in-memory map. For any existing feature that does not exist in
4157
        // the in-memory map, delete it from the database.
4158
        for _, feature := range existingFeatures {
×
4159
                // The feature is still present, so there are no updates to be
×
4160
                // made.
×
4161
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
4162
                        delete(newFeatures, feature.FeatureBit)
×
4163
                        continue
×
4164
                }
4165

4166
                // The feature is no longer present, so we remove it from the
4167
                // database.
4168
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
4169
                        NodeID:     nodeID,
×
4170
                        FeatureBit: feature.FeatureBit,
×
4171
                })
×
4172
                if err != nil {
×
4173
                        return fmt.Errorf("unable to delete node(%d) "+
×
4174
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
4175
                                err)
×
4176
                }
×
4177
        }
4178

4179
        // Any remaining entries in newFeatures are new features that need to be
4180
        // added to the database for the first time.
4181
        for feature := range newFeatures {
×
4182
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
4183
                        NodeID:     nodeID,
×
4184
                        FeatureBit: feature,
×
4185
                })
×
4186
                if err != nil {
×
4187
                        return fmt.Errorf("unable to insert node(%d) "+
×
4188
                                "feature(%v): %w", nodeID, feature, err)
×
4189
                }
×
4190
        }
4191

4192
        return nil
×
4193
}
4194

4195
// fetchNodeFeatures fetches the features for a node with the given public key.
4196
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
4197
        v lnwire.GossipVersion, nodePub route.Vertex) (*lnwire.FeatureVector,
NEW
4198
        error) {
×
4199

×
4200
        rows, err := queries.GetNodeFeaturesByPubKey(
×
4201
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
4202
                        PubKey:  nodePub[:],
×
NEW
4203
                        Version: int16(v),
×
4204
                },
×
4205
        )
×
4206
        if err != nil {
×
4207
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
4208
                        nodePub, err)
×
4209
        }
×
4210

4211
        features := lnwire.EmptyFeatureVector()
×
4212
        for _, bit := range rows {
×
4213
                features.Set(lnwire.FeatureBit(bit))
×
4214
        }
×
4215

4216
        return features, nil
×
4217
}
4218

4219
// dbAddressType identifies the kind of address stored in a row of the
// node_addresses table, and therefore how that row's address string must be
// serialised and deserialised. The numeric values are written to the
// database, so existing values must never be changed.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose host is an IPv4 address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose host is an IPv6 address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a v2 Tor onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a v3 Tor onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS marks a DNS hostname address.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque marks an opaque address (lnwire.OpaqueAddrs). It
	// sits at the top of the int8 range, well away from the concrete types.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
4232

4233
// collectAddressRecords collects the addresses from the provided
4234
// net.Addr slice and returns a map of dbAddressType to a slice of address
4235
// strings.
4236
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
4237
        error) {
×
4238

×
4239
        // Copy the nodes latest set of addresses.
×
4240
        newAddresses := map[dbAddressType][]string{
×
4241
                addressTypeIPv4:   {},
×
4242
                addressTypeIPv6:   {},
×
4243
                addressTypeTorV2:  {},
×
4244
                addressTypeTorV3:  {},
×
4245
                addressTypeDNS:    {},
×
4246
                addressTypeOpaque: {},
×
4247
        }
×
4248
        addAddr := func(t dbAddressType, addr net.Addr) {
×
4249
                newAddresses[t] = append(newAddresses[t], addr.String())
×
4250
        }
×
4251

4252
        for _, address := range addresses {
×
4253
                switch addr := address.(type) {
×
4254
                case *net.TCPAddr:
×
4255
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
4256
                                addAddr(addressTypeIPv4, addr)
×
4257
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
4258
                                addAddr(addressTypeIPv6, addr)
×
4259
                        } else {
×
4260
                                return nil, fmt.Errorf("unhandled IP "+
×
4261
                                        "address: %v", addr)
×
4262
                        }
×
4263

4264
                case *tor.OnionAddr:
×
4265
                        switch len(addr.OnionService) {
×
4266
                        case tor.V2Len:
×
4267
                                addAddr(addressTypeTorV2, addr)
×
4268
                        case tor.V3Len:
×
4269
                                addAddr(addressTypeTorV3, addr)
×
4270
                        default:
×
4271
                                return nil, fmt.Errorf("invalid length for " +
×
4272
                                        "a tor address")
×
4273
                        }
4274

4275
                case *lnwire.DNSAddress:
×
4276
                        addAddr(addressTypeDNS, addr)
×
4277

4278
                case *lnwire.OpaqueAddrs:
×
4279
                        addAddr(addressTypeOpaque, addr)
×
4280

4281
                default:
×
4282
                        return nil, fmt.Errorf("unhandled address type: %T",
×
4283
                                addr)
×
4284
                }
4285
        }
4286

4287
        return newAddresses, nil
×
4288
}
4289

4290
// upsertNodeAddresses updates the node's addresses in the database. This
4291
// includes deleting any existing addresses and inserting the new set of
4292
// addresses. The deletion is necessary since the ordering of the addresses may
4293
// change, and we need to ensure that the database reflects the latest set of
4294
// addresses so that at the time of reconstructing the node announcement, the
4295
// order is preserved and the signature over the message remains valid.
4296
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
4297
        addresses []net.Addr) error {
×
4298

×
4299
        // Delete any existing addresses for the node. This is required since
×
4300
        // even if the new set of addresses is the same, the ordering may have
×
4301
        // changed for a given address type.
×
4302
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
4303
        if err != nil {
×
4304
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
4305
                        nodeID, err)
×
4306
        }
×
4307

4308
        newAddresses, err := collectAddressRecords(addresses)
×
4309
        if err != nil {
×
4310
                return err
×
4311
        }
×
4312

4313
        // Any remaining entries in newAddresses are new addresses that need to
4314
        // be added to the database for the first time.
4315
        for addrType, addrList := range newAddresses {
×
4316
                for position, addr := range addrList {
×
4317
                        err := db.UpsertNodeAddress(
×
4318
                                ctx, sqlc.UpsertNodeAddressParams{
×
4319
                                        NodeID:   nodeID,
×
4320
                                        Type:     int16(addrType),
×
4321
                                        Address:  addr,
×
4322
                                        Position: int32(position),
×
4323
                                },
×
4324
                        )
×
4325
                        if err != nil {
×
4326
                                return fmt.Errorf("unable to insert "+
×
4327
                                        "node(%d) address(%v): %w", nodeID,
×
4328
                                        addr, err)
×
4329
                        }
×
4330
                }
4331
        }
4332

4333
        return nil
×
4334
}
4335

4336
// getNodeAddresses fetches the addresses for a node with the given DB ID.
4337
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
4338
        error) {
×
4339

×
4340
        // GetNodeAddresses ensures that the addresses for a given type are
×
4341
        // returned in the same order as they were inserted.
×
4342
        rows, err := db.GetNodeAddresses(ctx, id)
×
4343
        if err != nil {
×
4344
                return nil, err
×
4345
        }
×
4346

4347
        addresses := make([]net.Addr, 0, len(rows))
×
4348
        for _, row := range rows {
×
4349
                address := row.Address
×
4350

×
4351
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
4352
                if err != nil {
×
4353
                        return nil, fmt.Errorf("unable to parse address "+
×
4354
                                "for node(%d): %v: %w", id, address, err)
×
4355
                }
×
4356

4357
                addresses = append(addresses, addr)
×
4358
        }
4359

4360
        // If we have no addresses, then we'll return nil instead of an
4361
        // empty slice.
4362
        if len(addresses) == 0 {
×
4363
                addresses = nil
×
4364
        }
×
4365

4366
        return addresses, nil
×
4367
}
4368

4369
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
4370
// database. This includes updating any existing types, inserting any new types,
4371
// and deleting any types that are no longer present.
4372
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
4373
        nodeID int64, extraFields map[uint64][]byte) error {
×
4374

×
4375
        // Get any existing extra signed fields for the node.
×
4376
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
4377
        if err != nil {
×
4378
                return err
×
4379
        }
×
4380

4381
        // Make a lookup map of the existing field types so that we can use it
4382
        // to keep track of any fields we should delete.
4383
        m := make(map[uint64]bool)
×
4384
        for _, field := range existingFields {
×
4385
                m[uint64(field.Type)] = true
×
4386
        }
×
4387

4388
        // For all the new fields, we'll upsert them and remove them from the
4389
        // map of existing fields.
4390
        for tlvType, value := range extraFields {
×
4391
                err = db.UpsertNodeExtraType(
×
4392
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4393
                                NodeID: nodeID,
×
4394
                                Type:   int64(tlvType),
×
4395
                                Value:  value,
×
4396
                        },
×
4397
                )
×
4398
                if err != nil {
×
4399
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4400
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4401
                }
×
4402

4403
                // Remove the field from the map of existing fields if it was
4404
                // present.
4405
                delete(m, tlvType)
×
4406
        }
4407

4408
        // For all the fields that are left in the map of existing fields, we'll
4409
        // delete them as they are no longer present in the new set of fields.
4410
        for tlvType := range m {
×
4411
                err = db.DeleteExtraNodeType(
×
4412
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4413
                                NodeID: nodeID,
×
4414
                                Type:   int64(tlvType),
×
4415
                        },
×
4416
                )
×
4417
                if err != nil {
×
4418
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4419
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4420
                }
×
4421
        }
4422

4423
        return nil
×
4424
}
4425

4426
// srcNodeInfo holds the information about the source node of the graph.
4427
type srcNodeInfo struct {
4428
        // id is the DB level ID of the source node entry in the "nodes" table.
4429
        id int64
4430

4431
        // pub is the public key of the source node.
4432
        pub route.Vertex
4433
}
4434

4435
// sourceNode returns the DB node ID and pub key of the source node for the
4436
// specified protocol version.
4437
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4438
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4439

×
4440
        s.srcNodeMu.Lock()
×
4441
        defer s.srcNodeMu.Unlock()
×
4442

×
4443
        // If we already have the source node ID and pub key cached, then
×
4444
        // return them.
×
4445
        if info, ok := s.srcNodes[version]; ok {
×
4446
                return info.id, info.pub, nil
×
4447
        }
×
4448

4449
        var pubKey route.Vertex
×
4450

×
4451
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4452
        if err != nil {
×
4453
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4454
                        err)
×
4455
        }
×
4456

4457
        if len(nodes) == 0 {
×
4458
                return 0, pubKey, ErrSourceNodeNotSet
×
4459
        } else if len(nodes) > 1 {
×
4460
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4461
                        "protocol %s found", version)
×
4462
        }
×
4463

4464
        copy(pubKey[:], nodes[0].PubKey)
×
4465

×
4466
        s.srcNodes[version] = &srcNodeInfo{
×
4467
                id:  nodes[0].NodeID,
×
4468
                pub: pubKey,
×
4469
        }
×
4470

×
4471
        return nodes[0].NodeID, pubKey, nil
×
4472
}
4473

4474
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
4475
// This then produces a map from TLV type to value. If the input is not a
4476
// valid TLV stream, then an error is returned.
4477
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4478
        r := bytes.NewReader(data)
×
4479

×
4480
        tlvStream, err := tlv.NewStream()
×
4481
        if err != nil {
×
4482
                return nil, err
×
4483
        }
×
4484

4485
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4486
        // pass it into the P2P decoding variant.
4487
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4488
        if err != nil {
×
4489
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4490
        }
×
4491
        if len(parsedTypes) == 0 {
×
4492
                return nil, nil
×
4493
        }
×
4494

4495
        records := make(map[uint64][]byte)
×
4496
        for k, v := range parsedTypes {
×
4497
                records[uint64(k)] = v
×
4498
        }
×
4499

4500
        return records, nil
×
4501
}
4502

4503
// insertChannel inserts a new channel record into the database.
4504
func insertChannel(ctx context.Context, db SQLQueries,
4505
        edge *models.ChannelEdgeInfo) error {
×
4506

×
NEW
4507
        v := edge.Version
×
NEW
4508

×
4509
        // Make sure that at least a "shell" entry for each node is present in
×
4510
        // the nodes table.
×
NEW
4511
        node1DBID, err := maybeCreateShellNode(ctx, db, v, edge.NodeKey1Bytes)
×
4512
        if err != nil {
×
4513
                return fmt.Errorf("unable to create shell node: %w", err)
×
4514
        }
×
4515

NEW
4516
        node2DBID, err := maybeCreateShellNode(ctx, db, v, edge.NodeKey2Bytes)
×
4517
        if err != nil {
×
4518
                return fmt.Errorf("unable to create shell node: %w", err)
×
4519
        }
×
4520

4521
        var capacity sql.NullInt64
×
4522
        if edge.Capacity != 0 {
×
4523
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4524
        }
×
4525

4526
        createParams := sqlc.CreateChannelParams{
×
NEW
4527
                Version:  int16(v),
×
NEW
4528
                Scid:     channelIDToBytes(edge.ChannelID),
×
NEW
4529
                NodeID1:  node1DBID,
×
NEW
4530
                NodeID2:  node2DBID,
×
NEW
4531
                Outpoint: edge.ChannelPoint.String(),
×
NEW
4532
                Capacity: capacity,
×
NEW
4533
        }
×
NEW
4534
        edge.BitcoinKey1Bytes.WhenSome(func(vertex route.Vertex) {
×
NEW
4535
                createParams.BitcoinKey1 = vertex[:]
×
NEW
4536
        })
×
NEW
4537
        edge.BitcoinKey2Bytes.WhenSome(func(vertex route.Vertex) {
×
NEW
4538
                createParams.BitcoinKey2 = vertex[:]
×
NEW
4539
        })
×
NEW
4540
        edge.FundingScript.WhenSome(func(script []byte) {
×
NEW
4541
                createParams.FundingPkScript = script
×
NEW
4542
        })
×
NEW
4543
        edge.MerkleRootHash.WhenSome(func(hash chainhash.Hash) {
×
NEW
4544
                createParams.MerkleRootHash = hash[:]
×
NEW
4545
        })
×
4546

4547
        if edge.AuthProof != nil {
×
4548
                proof := edge.AuthProof
×
4549

×
NEW
4550
                createParams.Node1Signature = proof.NodeSig1()
×
NEW
4551
                createParams.Node2Signature = proof.NodeSig2()
×
NEW
4552
                createParams.Bitcoin1Signature = proof.BitcoinSig1()
×
NEW
4553
                createParams.Bitcoin2Signature = proof.BitcoinSig2()
×
NEW
4554
                createParams.Signature = proof.Sig()
×
UNCOV
4555
        }
×
4556

4557
        // Insert the new channel record.
4558
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4559
        if err != nil {
×
4560
                return err
×
4561
        }
×
4562

4563
        // Insert any channel features.
4564
        for feature := range edge.Features.Features() {
×
4565
                err = db.InsertChannelFeature(
×
4566
                        ctx, sqlc.InsertChannelFeatureParams{
×
4567
                                ChannelID:  dbChanID,
×
4568
                                FeatureBit: int32(feature),
×
4569
                        },
×
4570
                )
×
4571
                if err != nil {
×
4572
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4573
                                "feature(%v): %w", dbChanID, feature, err)
×
4574
                }
×
4575
        }
4576

4577
        // Finally, insert any extra TLV fields in the channel announcement.
NEW
4578
        extra := edge.ExtraSignedFields
×
NEW
4579
        if v == lnwire.GossipVersion1 {
×
NEW
4580
                extra, err = marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
NEW
4581
                if err != nil {
×
NEW
4582
                        return fmt.Errorf("unable to marshal extra opaque "+
×
NEW
4583
                                "data: %w", err)
×
NEW
4584
                }
×
4585
        }
4586

4587
        for tlvType, value := range extra {
×
4588
                err := db.UpsertChannelExtraType(
×
4589
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4590
                                ChannelID: dbChanID,
×
4591
                                Type:      int64(tlvType),
×
4592
                                Value:     value,
×
4593
                        },
×
4594
                )
×
4595
                if err != nil {
×
4596
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4597
                                "extra signed field(%v): %w", edge.ChannelID,
×
4598
                                tlvType, err)
×
4599
                }
×
4600
        }
4601

4602
        return nil
×
4603
}
4604

4605
// maybeCreateShellNode checks if a shell node entry exists for the
4606
// given public key. If it does not exist, then a new shell node entry is
4607
// created. The ID of the node is returned. A shell node only has a protocol
4608
// version and public key persisted.
4609
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
NEW
4610
        v lnwire.GossipVersion, pubKey route.Vertex) (int64, error) {
×
4611

×
4612
        dbNode, err := db.GetNodeByPubKey(
×
4613
                ctx, sqlc.GetNodeByPubKeyParams{
×
4614
                        PubKey:  pubKey[:],
×
NEW
4615
                        Version: int16(v),
×
4616
                },
×
4617
        )
×
4618
        // The node exists. Return the ID.
×
4619
        if err == nil {
×
4620
                return dbNode.ID, nil
×
4621
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4622
                return 0, err
×
4623
        }
×
4624

4625
        // Otherwise, the node does not exist, so we create a shell entry for
4626
        // it.
4627
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
NEW
4628
                Version: int16(v),
×
4629
                PubKey:  pubKey[:],
×
4630
        })
×
4631
        if err != nil {
×
4632
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4633
        }
×
4634

4635
        return id, nil
×
4636
}
4637

4638
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4639
// the database. This includes deleting any existing types and then inserting
4640
// the new types.
4641
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4642
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4643

×
4644
        // Delete all existing extra signed fields for the channel policy.
×
4645
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4646
        if err != nil {
×
4647
                return fmt.Errorf("unable to delete "+
×
4648
                        "existing policy extra signed fields for policy %d: %w",
×
4649
                        chanPolicyID, err)
×
4650
        }
×
4651

4652
        // Insert all new extra signed fields for the channel policy.
4653
        for tlvType, value := range extraFields {
×
4654
                err = db.UpsertChanPolicyExtraType(
×
4655
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4656
                                ChannelPolicyID: chanPolicyID,
×
4657
                                Type:            int64(tlvType),
×
4658
                                Value:           value,
×
4659
                        },
×
4660
                )
×
4661
                if err != nil {
×
4662
                        return fmt.Errorf("unable to insert "+
×
4663
                                "channel_policy(%d) extra signed field(%v): %w",
×
4664
                                chanPolicyID, tlvType, err)
×
4665
                }
×
4666
        }
4667

4668
        return nil
×
4669
}
4670

4671
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4672
// provided dbChanRow and also fetches any other required information
4673
// to construct the edge info.
4674
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4675
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4676
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4677

×
4678
        data, err := batchLoadChannelData(
×
4679
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4680
        )
×
4681
        if err != nil {
×
4682
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4683
                        err)
×
4684
        }
×
4685

4686
        return buildEdgeInfoWithBatchData(
×
4687
                cfg.ChainHash, dbChan, node1, node2, data,
×
4688
        )
×
4689
}
4690

4691
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4692
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4693
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4694
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4695

×
NEW
4696
        v := lnwire.GossipVersion(dbChan.Version)
×
NEW
4697
        if !isKnownGossipVersion(v) {
×
NEW
4698
                return nil, fmt.Errorf("unknown channel version: %d", v)
×
UNCOV
4699
        }
×
4700

4701
        // Use pre-loaded features and extras types.
4702
        fv := lnwire.EmptyFeatureVector()
×
4703
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4704
                for _, bit := range features {
×
4705
                        fv.Set(lnwire.FeatureBit(bit))
×
4706
                }
×
4707
        }
4708

4709
        var extras map[uint64][]byte
×
4710
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4711
        if exists {
×
4712
                extras = channelExtras
×
4713
        } else {
×
4714
                extras = make(map[uint64][]byte)
×
4715
        }
×
4716

4717
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4718
        if err != nil {
×
4719
                return nil, err
×
4720
        }
×
4721

4722
        // Build the appropriate channel based on version.
NEW
4723
        var channel *models.ChannelEdgeInfo
×
NEW
4724
        switch v {
×
NEW
4725
        case lnwire.GossipVersion1:
×
NEW
4726
                // For v1, serialize extras into ExtraOpaqueData.
×
NEW
4727
                recs, err := lnwire.CustomRecords(extras).Serialize()
×
NEW
4728
                if err != nil {
×
NEW
4729
                        return nil, fmt.Errorf("unable to serialize extra "+
×
NEW
4730
                                "signed fields: %w", err)
×
NEW
4731
                }
×
NEW
4732
                if recs == nil {
×
NEW
4733
                        recs = make([]byte, 0)
×
NEW
4734
                }
×
4735

4736
                // Bitcoin keys are required for v1.
NEW
4737
                btcKey1, err := route.NewVertexFromBytes(dbChan.BitcoinKey1)
×
NEW
4738
                if err != nil {
×
NEW
4739
                        return nil, err
×
NEW
4740
                }
×
NEW
4741
                btcKey2, err := route.NewVertexFromBytes(dbChan.BitcoinKey2)
×
NEW
4742
                if err != nil {
×
NEW
4743
                        return nil, err
×
NEW
4744
                }
×
4745

NEW
4746
                channel, err = models.NewV1Channel(
×
NEW
4747
                        byteOrder.Uint64(dbChan.Scid), chain, node1, node2,
×
NEW
4748
                        &models.ChannelV1Fields{
×
NEW
4749
                                BitcoinKey1Bytes: btcKey1,
×
NEW
4750
                                BitcoinKey2Bytes: btcKey2,
×
NEW
4751
                                ExtraOpaqueData:  recs,
×
NEW
4752
                        },
×
NEW
4753
                        models.WithChannelPoint(*op),
×
NEW
4754
                        models.WithCapacity(
×
NEW
4755
                                btcutil.Amount(dbChan.Capacity.Int64),
×
NEW
4756
                        ),
×
NEW
4757
                        models.WithFeatures(fv.RawFeatureVector),
×
NEW
4758
                )
×
NEW
4759
                if err != nil {
×
NEW
4760
                        return nil, err
×
NEW
4761
                }
×
4762

4763
                // For v1 channels, attach the auth proof if all four
4764
                // signatures are present.
NEW
4765
                if len(dbChan.Bitcoin1Signature) > 0 {
×
NEW
4766
                        channel.AuthProof = models.NewV1ChannelAuthProof(
×
NEW
4767
                                dbChan.Node1Signature,
×
NEW
4768
                                dbChan.Node2Signature,
×
NEW
4769
                                dbChan.Bitcoin1Signature,
×
NEW
4770
                                dbChan.Bitcoin2Signature,
×
NEW
4771
                        )
×
4772
                }
×
4773

NEW
4774
        case lnwire.GossipVersion2:
×
NEW
4775
                v2Fields := &models.ChannelV2Fields{
×
NEW
4776
                        ExtraSignedFields: extras,
×
NEW
4777
                }
×
NEW
4778

×
NEW
4779
                // For v2, bitcoin keys are optional.
×
NEW
4780
                if len(dbChan.BitcoinKey1) > 0 {
×
NEW
4781
                        btcKey1, err := route.NewVertexFromBytes(
×
NEW
4782
                                dbChan.BitcoinKey1,
×
NEW
4783
                        )
×
NEW
4784
                        if err != nil {
×
NEW
4785
                                return nil, err
×
NEW
4786
                        }
×
NEW
4787
                        v2Fields.BitcoinKey1Bytes = fn.Some(btcKey1)
×
4788
                }
NEW
4789
                if len(dbChan.BitcoinKey2) > 0 {
×
NEW
4790
                        btcKey2, err := route.NewVertexFromBytes(
×
NEW
4791
                                dbChan.BitcoinKey2,
×
NEW
4792
                        )
×
NEW
4793
                        if err != nil {
×
NEW
4794
                                return nil, err
×
NEW
4795
                        }
×
NEW
4796
                        v2Fields.BitcoinKey2Bytes = fn.Some(btcKey2)
×
4797
                }
4798

4799
                // Parse funding script if present.
NEW
4800
                if len(dbChan.FundingPkScript) > 0 {
×
NEW
4801
                        v2Fields.FundingScript = fn.Some(dbChan.FundingPkScript)
×
NEW
4802
                }
×
4803

4804
                // Parse merkle root hash if present.
NEW
4805
                if len(dbChan.MerkleRootHash) > 0 {
×
NEW
4806
                        var hash chainhash.Hash
×
NEW
4807
                        copy(hash[:], dbChan.MerkleRootHash)
×
NEW
4808
                        v2Fields.MerkleRootHash = fn.Some(hash)
×
NEW
4809
                }
×
4810

NEW
4811
                opts := []models.EdgeModifier{
×
NEW
4812
                        models.WithChannelPoint(*op),
×
NEW
4813
                        models.WithCapacity(btcutil.Amount(
×
NEW
4814
                                dbChan.Capacity.Int64,
×
NEW
4815
                        )),
×
NEW
4816
                        models.WithFeatures(fv.RawFeatureVector),
×
NEW
4817
                }
×
NEW
4818

×
NEW
4819
                // For v2 channels, attach the auth proof if the signature is
×
NEW
4820
                // present.
×
NEW
4821
                if len(dbChan.Signature) > 0 {
×
NEW
4822
                        proof := models.NewV2ChannelAuthProof(dbChan.Signature)
×
NEW
4823
                        opts = append(opts, models.WithChanProof(proof))
×
NEW
4824
                }
×
4825

NEW
4826
                channel, err = models.NewV2Channel(
×
NEW
4827
                        byteOrder.Uint64(dbChan.Scid), chain, node1, node2,
×
NEW
4828
                        v2Fields, opts...,
×
NEW
4829
                )
×
NEW
4830
                if err != nil {
×
NEW
4831
                        return nil, err
×
NEW
4832
                }
×
4833

NEW
4834
        default:
×
NEW
4835
                return nil, fmt.Errorf("unsupported channel version: %d", v)
×
4836
        }
4837

4838
        return channel, nil
×
4839
}
4840

4841
// buildNodeVertices is a helper that converts raw node public keys
4842
// into route.Vertex instances.
4843
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4844
        route.Vertex, error) {
×
4845

×
4846
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4847
        if err != nil {
×
4848
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4849
                        "create vertex from node1 pubkey: %w", err)
×
4850
        }
×
4851

4852
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4853
        if err != nil {
×
4854
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4855
                        "create vertex from node2 pubkey: %w", err)
×
4856
        }
×
4857

4858
        return node1Vertex, node2Vertex, nil
×
4859
}
4860

4861
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4862
// retrieves all the extra info required to build the complete
4863
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4864
// the provided sqlc.GraphChannelPolicy records are nil.
4865
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4866
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4867
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4868
        *models.ChannelEdgePolicy, error) {
×
4869

×
4870
        if dbPol1 == nil && dbPol2 == nil {
×
4871
                return nil, nil, nil
×
4872
        }
×
4873

4874
        // TODO(elle): update to support v2 policies.
NEW
4875
        if dbPol1 != nil &&
×
NEW
4876
                lnwire.GossipVersion(dbPol1.Version) != lnwire.GossipVersion1 {
×
NEW
4877

×
NEW
4878
                return nil, nil, fmt.Errorf("unsupported policy1 version: %d",
×
NEW
4879
                        dbPol1.Version)
×
NEW
4880
        }
×
4881

NEW
4882
        if dbPol2 != nil &&
×
NEW
4883
                lnwire.GossipVersion(dbPol2.Version) != lnwire.GossipVersion1 {
×
NEW
4884

×
NEW
4885
                return nil, nil, fmt.Errorf("unsupported policy2 version: %d",
×
NEW
4886
                        dbPol2.Version)
×
NEW
4887
        }
×
4888

4889
        var policyIDs = make([]int64, 0, 2)
×
4890
        if dbPol1 != nil {
×
4891
                policyIDs = append(policyIDs, dbPol1.ID)
×
4892
        }
×
4893
        if dbPol2 != nil {
×
4894
                policyIDs = append(policyIDs, dbPol2.ID)
×
4895
        }
×
4896

4897
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4898
        if err != nil {
×
4899
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4900
                        "data: %w", err)
×
4901
        }
×
4902

4903
        pol1, err := buildChanPolicyWithBatchData(
×
NEW
4904
                true, dbPol1, channelID, node2, batchData,
×
4905
        )
×
4906
        if err != nil {
×
4907
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4908
        }
×
4909

4910
        pol2, err := buildChanPolicyWithBatchData(
×
NEW
4911
                false, dbPol2, channelID, node1, batchData,
×
4912
        )
×
4913
        if err != nil {
×
4914
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4915
        }
×
4916

4917
        return pol1, pol2, nil
×
4918
}
4919

4920
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4921
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4922
// then nil is returned for it.
4923
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4924
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4925
        *models.CachedEdgePolicy, error) {
×
4926

×
4927
        var p1, p2 *models.CachedEdgePolicy
×
4928
        if dbPol1 != nil {
×
NEW
4929
                policy1, err := buildChanPolicy(
×
NEW
4930
                        true, *dbPol1, channelID, nil, node2,
×
NEW
4931
                )
×
4932
                if err != nil {
×
4933
                        return nil, nil, err
×
4934
                }
×
4935

4936
                p1 = models.NewCachedPolicy(policy1)
×
4937
        }
4938
        if dbPol2 != nil {
×
NEW
4939
                policy2, err := buildChanPolicy(
×
NEW
4940
                        false, *dbPol2, channelID, nil, node1,
×
NEW
4941
                )
×
4942
                if err != nil {
×
4943
                        return nil, nil, err
×
4944
                }
×
4945

4946
                p2 = models.NewCachedPolicy(policy2)
×
4947
        }
4948

4949
        return p1, p2, nil
×
4950
}
4951

4952
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4953
// provided sqlc.GraphChannelPolicy and other required information.
4954
func buildChanPolicy(isNode1 bool, dbPolicy sqlc.GraphChannelPolicy,
4955
        channelID uint64, extras map[uint64][]byte,
4956
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4957

×
4958
        var inboundFee fn.Option[lnwire.Fee]
×
4959
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4960
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4961

×
4962
                inboundFee = fn.Some(lnwire.Fee{
×
4963
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4964
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4965
                })
×
4966
        }
×
4967

NEW
4968
        p := &models.ChannelEdgePolicy{
×
NEW
4969
                Version:       lnwire.GossipVersion(dbPolicy.Version),
×
NEW
4970
                SigBytes:      dbPolicy.Signature,
×
NEW
4971
                ChannelID:     channelID,
×
NEW
4972
                SecondPeer:    !isNode1,
×
4973
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4974
                MinHTLC: lnwire.MilliSatoshi(
×
4975
                        dbPolicy.MinHtlcMsat,
×
4976
                ),
×
4977
                MaxHTLC: lnwire.MilliSatoshi(
×
4978
                        dbPolicy.MaxHtlcMsat.Int64,
×
4979
                ),
×
4980
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4981
                        dbPolicy.BaseFeeMsat,
×
4982
                ),
×
4983
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4984
                ToNode:                    toNode,
×
4985
                InboundFee:                inboundFee,
×
NEW
4986
        }
×
NEW
4987

×
NEW
4988
        if p.Version == lnwire.GossipVersion1 {
×
NEW
4989
                recs, err := lnwire.CustomRecords(extras).Serialize()
×
NEW
4990
                if err != nil {
×
NEW
4991
                        return nil, fmt.Errorf("unable to serialize extra "+
×
NEW
4992
                                "signed fields: %w", err)
×
NEW
4993
                }
×
4994

NEW
4995
                p.ExtraOpaqueData = recs
×
NEW
4996
                p.LastUpdate = time.Unix(dbPolicy.LastUpdate.Int64, 0)
×
NEW
4997
                //nolint:ll
×
NEW
4998
                p.MessageFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
NEW
4999
                        dbPolicy.MessageFlags,
×
NEW
5000
                )
×
NEW
5001
                //nolint:ll
×
NEW
5002
                p.ChannelFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
NEW
5003
                        dbPolicy.ChannelFlags,
×
NEW
5004
                )
×
NEW
5005
        } else {
×
NEW
5006
                if dbPolicy.BlockHeight.Valid {
×
NEW
5007
                        p.LastBlockHeight = uint32(
×
NEW
5008
                                dbPolicy.BlockHeight.Int64,
×
NEW
5009
                        )
×
NEW
5010
                }
×
5011

5012
                //nolint:ll
NEW
5013
                p.DisableFlags = sqldb.ExtractSqlInt16[lnwire.ChanUpdateDisableFlags](
×
NEW
5014
                        dbPolicy.DisableFlags,
×
NEW
5015
                )
×
NEW
5016
                p.ExtraSignedFields = extras
×
5017
        }
5018

NEW
5019
        return p, nil
×
5020
}
5021

5022
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
5023
// row which is expected to be a sqlc type that contains channel policy
5024
// information. It returns two policies, which may be nil if the policy
5025
// information is not present in the row.
5026
//
5027
//nolint:ll,dupl,funlen
5028
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
5029
        *sqlc.GraphChannelPolicy, error) {
×
5030

×
5031
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
5032
        switch r := row.(type) {
×
5033
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
5034
                if r.Policy1Timelock.Valid {
×
5035
                        policy1 = &sqlc.GraphChannelPolicy{
×
NEW
5036
                                Version:                 int16(lnwire.GossipVersion1),
×
5037
                                Timelock:                r.Policy1Timelock.Int32,
×
5038
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5039
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5040
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5041
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5042
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5043
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5044
                                Disabled:                r.Policy1Disabled,
×
5045
                                MessageFlags:            r.Policy1MessageFlags,
×
5046
                                ChannelFlags:            r.Policy1ChannelFlags,
×
NEW
5047
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5048
                                DisableFlags:            r.Policy1DisableFlags,
×
5049
                        }
×
5050
                }
×
5051
                if r.Policy2Timelock.Valid {
×
5052
                        policy2 = &sqlc.GraphChannelPolicy{
×
NEW
5053
                                Version:                 int16(lnwire.GossipVersion1),
×
5054
                                Timelock:                r.Policy2Timelock.Int32,
×
5055
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5056
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5057
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5058
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5059
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5060
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5061
                                Disabled:                r.Policy2Disabled,
×
5062
                                MessageFlags:            r.Policy2MessageFlags,
×
5063
                                ChannelFlags:            r.Policy2ChannelFlags,
×
NEW
5064
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5065
                                DisableFlags:            r.Policy2DisableFlags,
×
5066
                        }
×
5067
                }
×
5068

5069
                return policy1, policy2, nil
×
5070

5071
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
5072
                if r.Policy1ID.Valid {
×
5073
                        policy1 = &sqlc.GraphChannelPolicy{
×
5074
                                ID:                      r.Policy1ID.Int64,
×
5075
                                Version:                 r.Policy1Version.Int16,
×
5076
                                ChannelID:               r.GraphChannel.ID,
×
5077
                                NodeID:                  r.Policy1NodeID.Int64,
×
5078
                                Timelock:                r.Policy1Timelock.Int32,
×
5079
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5080
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5081
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5082
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5083
                                LastUpdate:              r.Policy1LastUpdate,
×
5084
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5085
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5086
                                Disabled:                r.Policy1Disabled,
×
5087
                                MessageFlags:            r.Policy1MessageFlags,
×
5088
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5089
                                Signature:               r.Policy1Signature,
×
NEW
5090
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5091
                                DisableFlags:            r.Policy1DisableFlags,
×
5092
                        }
×
5093
                }
×
5094
                if r.Policy2ID.Valid {
×
5095
                        policy2 = &sqlc.GraphChannelPolicy{
×
5096
                                ID:                      r.Policy2ID.Int64,
×
5097
                                Version:                 r.Policy2Version.Int16,
×
5098
                                ChannelID:               r.GraphChannel.ID,
×
5099
                                NodeID:                  r.Policy2NodeID.Int64,
×
5100
                                Timelock:                r.Policy2Timelock.Int32,
×
5101
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5102
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5103
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5104
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5105
                                LastUpdate:              r.Policy2LastUpdate,
×
5106
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5107
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5108
                                Disabled:                r.Policy2Disabled,
×
5109
                                MessageFlags:            r.Policy2MessageFlags,
×
5110
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5111
                                Signature:               r.Policy2Signature,
×
NEW
5112
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5113
                                DisableFlags:            r.Policy2DisableFlags,
×
5114
                        }
×
5115
                }
×
5116

5117
                return policy1, policy2, nil
×
5118

5119
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
5120
                if r.Policy1ID.Valid {
×
5121
                        policy1 = &sqlc.GraphChannelPolicy{
×
5122
                                ID:                      r.Policy1ID.Int64,
×
5123
                                Version:                 r.Policy1Version.Int16,
×
5124
                                ChannelID:               r.GraphChannel.ID,
×
5125
                                NodeID:                  r.Policy1NodeID.Int64,
×
5126
                                Timelock:                r.Policy1Timelock.Int32,
×
5127
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5128
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5129
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5130
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5131
                                LastUpdate:              r.Policy1LastUpdate,
×
5132
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5133
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5134
                                Disabled:                r.Policy1Disabled,
×
5135
                                MessageFlags:            r.Policy1MessageFlags,
×
5136
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5137
                                Signature:               r.Policy1Signature,
×
NEW
5138
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5139
                                DisableFlags:            r.Policy1DisableFlags,
×
5140
                        }
×
5141
                }
×
5142
                if r.Policy2ID.Valid {
×
5143
                        policy2 = &sqlc.GraphChannelPolicy{
×
5144
                                ID:                      r.Policy2ID.Int64,
×
5145
                                Version:                 r.Policy2Version.Int16,
×
5146
                                ChannelID:               r.GraphChannel.ID,
×
5147
                                NodeID:                  r.Policy2NodeID.Int64,
×
5148
                                Timelock:                r.Policy2Timelock.Int32,
×
5149
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5150
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5151
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5152
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5153
                                LastUpdate:              r.Policy2LastUpdate,
×
5154
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5155
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5156
                                Disabled:                r.Policy2Disabled,
×
5157
                                MessageFlags:            r.Policy2MessageFlags,
×
5158
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5159
                                Signature:               r.Policy2Signature,
×
NEW
5160
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5161
                                DisableFlags:            r.Policy2DisableFlags,
×
5162
                        }
×
5163
                }
×
5164

5165
                return policy1, policy2, nil
×
5166

5167
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
5168
                if r.Policy1ID.Valid {
×
5169
                        policy1 = &sqlc.GraphChannelPolicy{
×
5170
                                ID:                      r.Policy1ID.Int64,
×
5171
                                Version:                 r.Policy1Version.Int16,
×
5172
                                ChannelID:               r.GraphChannel.ID,
×
5173
                                NodeID:                  r.Policy1NodeID.Int64,
×
5174
                                Timelock:                r.Policy1Timelock.Int32,
×
5175
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5176
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5177
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5178
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5179
                                LastUpdate:              r.Policy1LastUpdate,
×
5180
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5181
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5182
                                Disabled:                r.Policy1Disabled,
×
5183
                                MessageFlags:            r.Policy1MessageFlags,
×
5184
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5185
                                Signature:               r.Policy1Signature,
×
NEW
5186
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5187
                                DisableFlags:            r.Policy1DisableFlags,
×
5188
                        }
×
5189
                }
×
5190
                if r.Policy2ID.Valid {
×
5191
                        policy2 = &sqlc.GraphChannelPolicy{
×
5192
                                ID:                      r.Policy2ID.Int64,
×
5193
                                Version:                 r.Policy2Version.Int16,
×
5194
                                ChannelID:               r.GraphChannel.ID,
×
5195
                                NodeID:                  r.Policy2NodeID.Int64,
×
5196
                                Timelock:                r.Policy2Timelock.Int32,
×
5197
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5198
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5199
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5200
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5201
                                LastUpdate:              r.Policy2LastUpdate,
×
5202
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5203
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5204
                                Disabled:                r.Policy2Disabled,
×
5205
                                MessageFlags:            r.Policy2MessageFlags,
×
5206
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5207
                                Signature:               r.Policy2Signature,
×
NEW
5208
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5209
                                DisableFlags:            r.Policy2DisableFlags,
×
5210
                        }
×
5211
                }
×
5212

5213
                return policy1, policy2, nil
×
5214

5215
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
5216
                if r.Policy1ID.Valid {
×
5217
                        policy1 = &sqlc.GraphChannelPolicy{
×
5218
                                ID:                      r.Policy1ID.Int64,
×
5219
                                Version:                 r.Policy1Version.Int16,
×
5220
                                ChannelID:               r.GraphChannel.ID,
×
5221
                                NodeID:                  r.Policy1NodeID.Int64,
×
5222
                                Timelock:                r.Policy1Timelock.Int32,
×
5223
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5224
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5225
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5226
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5227
                                LastUpdate:              r.Policy1LastUpdate,
×
5228
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5229
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5230
                                Disabled:                r.Policy1Disabled,
×
5231
                                MessageFlags:            r.Policy1MessageFlags,
×
5232
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5233
                                Signature:               r.Policy1Signature,
×
NEW
5234
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5235
                                DisableFlags:            r.Policy1DisableFlags,
×
5236
                        }
×
5237
                }
×
5238
                if r.Policy2ID.Valid {
×
5239
                        policy2 = &sqlc.GraphChannelPolicy{
×
5240
                                ID:                      r.Policy2ID.Int64,
×
5241
                                Version:                 r.Policy2Version.Int16,
×
5242
                                ChannelID:               r.GraphChannel.ID,
×
5243
                                NodeID:                  r.Policy2NodeID.Int64,
×
5244
                                Timelock:                r.Policy2Timelock.Int32,
×
5245
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5246
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5247
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5248
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5249
                                LastUpdate:              r.Policy2LastUpdate,
×
5250
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5251
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5252
                                Disabled:                r.Policy2Disabled,
×
5253
                                MessageFlags:            r.Policy2MessageFlags,
×
5254
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5255
                                Signature:               r.Policy2Signature,
×
NEW
5256
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5257
                                DisableFlags:            r.Policy2DisableFlags,
×
5258
                        }
×
5259
                }
×
5260

5261
                return policy1, policy2, nil
×
5262

5263
        case sqlc.ListChannelsForNodeIDsRow:
×
5264
                if r.Policy1ID.Valid {
×
5265
                        policy1 = &sqlc.GraphChannelPolicy{
×
5266
                                ID:                      r.Policy1ID.Int64,
×
5267
                                Version:                 r.Policy1Version.Int16,
×
5268
                                ChannelID:               r.GraphChannel.ID,
×
5269
                                NodeID:                  r.Policy1NodeID.Int64,
×
5270
                                Timelock:                r.Policy1Timelock.Int32,
×
5271
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5272
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5273
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5274
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5275
                                LastUpdate:              r.Policy1LastUpdate,
×
5276
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5277
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5278
                                Disabled:                r.Policy1Disabled,
×
5279
                                MessageFlags:            r.Policy1MessageFlags,
×
5280
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5281
                                Signature:               r.Policy1Signature,
×
NEW
5282
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5283
                                DisableFlags:            r.Policy1DisableFlags,
×
5284
                        }
×
5285
                }
×
5286
                if r.Policy2ID.Valid {
×
5287
                        policy2 = &sqlc.GraphChannelPolicy{
×
5288
                                ID:                      r.Policy2ID.Int64,
×
5289
                                Version:                 r.Policy2Version.Int16,
×
5290
                                ChannelID:               r.GraphChannel.ID,
×
5291
                                NodeID:                  r.Policy2NodeID.Int64,
×
5292
                                Timelock:                r.Policy2Timelock.Int32,
×
5293
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5294
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5295
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5296
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5297
                                LastUpdate:              r.Policy2LastUpdate,
×
5298
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5299
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5300
                                Disabled:                r.Policy2Disabled,
×
5301
                                MessageFlags:            r.Policy2MessageFlags,
×
5302
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5303
                                Signature:               r.Policy2Signature,
×
NEW
5304
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5305
                                DisableFlags:            r.Policy2DisableFlags,
×
5306
                        }
×
5307
                }
×
5308

5309
                return policy1, policy2, nil
×
5310

5311
        case sqlc.ListChannelsByNodeIDRow:
×
5312
                if r.Policy1ID.Valid {
×
5313
                        policy1 = &sqlc.GraphChannelPolicy{
×
5314
                                ID:                      r.Policy1ID.Int64,
×
5315
                                Version:                 r.Policy1Version.Int16,
×
5316
                                ChannelID:               r.GraphChannel.ID,
×
5317
                                NodeID:                  r.Policy1NodeID.Int64,
×
5318
                                Timelock:                r.Policy1Timelock.Int32,
×
5319
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5320
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5321
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5322
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5323
                                LastUpdate:              r.Policy1LastUpdate,
×
5324
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5325
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5326
                                Disabled:                r.Policy1Disabled,
×
5327
                                MessageFlags:            r.Policy1MessageFlags,
×
5328
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5329
                                Signature:               r.Policy1Signature,
×
NEW
5330
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5331
                                DisableFlags:            r.Policy1DisableFlags,
×
5332
                        }
×
5333
                }
×
5334
                if r.Policy2ID.Valid {
×
5335
                        policy2 = &sqlc.GraphChannelPolicy{
×
5336
                                ID:                      r.Policy2ID.Int64,
×
5337
                                Version:                 r.Policy2Version.Int16,
×
5338
                                ChannelID:               r.GraphChannel.ID,
×
5339
                                NodeID:                  r.Policy2NodeID.Int64,
×
5340
                                Timelock:                r.Policy2Timelock.Int32,
×
5341
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5342
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5343
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5344
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5345
                                LastUpdate:              r.Policy2LastUpdate,
×
5346
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5347
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5348
                                Disabled:                r.Policy2Disabled,
×
5349
                                MessageFlags:            r.Policy2MessageFlags,
×
5350
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5351
                                Signature:               r.Policy2Signature,
×
NEW
5352
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5353
                                DisableFlags:            r.Policy2DisableFlags,
×
5354
                        }
×
5355
                }
×
5356

5357
                return policy1, policy2, nil
×
5358

5359
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
5360
                if r.Policy1ID.Valid {
×
5361
                        policy1 = &sqlc.GraphChannelPolicy{
×
5362
                                ID:                      r.Policy1ID.Int64,
×
5363
                                Version:                 r.Policy1Version.Int16,
×
5364
                                ChannelID:               r.GraphChannel.ID,
×
5365
                                NodeID:                  r.Policy1NodeID.Int64,
×
5366
                                Timelock:                r.Policy1Timelock.Int32,
×
5367
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5368
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5369
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5370
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5371
                                LastUpdate:              r.Policy1LastUpdate,
×
5372
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5373
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5374
                                Disabled:                r.Policy1Disabled,
×
5375
                                MessageFlags:            r.Policy1MessageFlags,
×
5376
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5377
                                Signature:               r.Policy1Signature,
×
NEW
5378
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5379
                                DisableFlags:            r.Policy1DisableFlags,
×
5380
                        }
×
5381
                }
×
5382
                if r.Policy2ID.Valid {
×
5383
                        policy2 = &sqlc.GraphChannelPolicy{
×
5384
                                ID:                      r.Policy2ID.Int64,
×
5385
                                Version:                 r.Policy2Version.Int16,
×
5386
                                ChannelID:               r.GraphChannel.ID,
×
5387
                                NodeID:                  r.Policy2NodeID.Int64,
×
5388
                                Timelock:                r.Policy2Timelock.Int32,
×
5389
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5390
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5391
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5392
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5393
                                LastUpdate:              r.Policy2LastUpdate,
×
5394
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5395
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5396
                                Disabled:                r.Policy2Disabled,
×
5397
                                MessageFlags:            r.Policy2MessageFlags,
×
5398
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5399
                                Signature:               r.Policy2Signature,
×
NEW
5400
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5401
                                DisableFlags:            r.Policy2DisableFlags,
×
5402
                        }
×
5403
                }
×
5404

5405
                return policy1, policy2, nil
×
5406

5407
        case sqlc.GetChannelsByIDsRow:
×
5408
                if r.Policy1ID.Valid {
×
5409
                        policy1 = &sqlc.GraphChannelPolicy{
×
5410
                                ID:                      r.Policy1ID.Int64,
×
5411
                                Version:                 r.Policy1Version.Int16,
×
5412
                                ChannelID:               r.GraphChannel.ID,
×
5413
                                NodeID:                  r.Policy1NodeID.Int64,
×
5414
                                Timelock:                r.Policy1Timelock.Int32,
×
5415
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
5416
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
5417
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
5418
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
5419
                                LastUpdate:              r.Policy1LastUpdate,
×
5420
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
5421
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
5422
                                Disabled:                r.Policy1Disabled,
×
5423
                                MessageFlags:            r.Policy1MessageFlags,
×
5424
                                ChannelFlags:            r.Policy1ChannelFlags,
×
5425
                                Signature:               r.Policy1Signature,
×
NEW
5426
                                BlockHeight:             r.Policy1BlockHeight,
×
NEW
5427
                                DisableFlags:            r.Policy1DisableFlags,
×
5428
                        }
×
5429
                }
×
5430
                if r.Policy2ID.Valid {
×
5431
                        policy2 = &sqlc.GraphChannelPolicy{
×
5432
                                ID:                      r.Policy2ID.Int64,
×
5433
                                Version:                 r.Policy2Version.Int16,
×
5434
                                ChannelID:               r.GraphChannel.ID,
×
5435
                                NodeID:                  r.Policy2NodeID.Int64,
×
5436
                                Timelock:                r.Policy2Timelock.Int32,
×
5437
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
5438
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
5439
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
5440
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
5441
                                LastUpdate:              r.Policy2LastUpdate,
×
5442
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
5443
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
5444
                                Disabled:                r.Policy2Disabled,
×
5445
                                MessageFlags:            r.Policy2MessageFlags,
×
5446
                                ChannelFlags:            r.Policy2ChannelFlags,
×
5447
                                Signature:               r.Policy2Signature,
×
NEW
5448
                                BlockHeight:             r.Policy2BlockHeight,
×
NEW
5449
                                DisableFlags:            r.Policy2DisableFlags,
×
5450
                        }
×
5451
                }
×
5452

5453
                return policy1, policy2, nil
×
5454

5455
        default:
×
5456
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
5457
                        "extractChannelPolicies: %T", r)
×
5458
        }
5459
}
5460

5461
// channelIDToBytes converts a channel ID (SCID) to a byte array
5462
// representation.
5463
func channelIDToBytes(channelID uint64) []byte {
×
5464
        var chanIDB [8]byte
×
5465
        byteOrder.PutUint64(chanIDB[:], channelID)
×
5466

×
5467
        return chanIDB[:]
×
5468
}
×
5469

5470
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
5471
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
5472
        if len(addresses) == 0 {
×
5473
                return nil, nil
×
5474
        }
×
5475

5476
        result := make([]net.Addr, 0, len(addresses))
×
5477
        for _, addr := range addresses {
×
5478
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
5479
                if err != nil {
×
5480
                        return nil, fmt.Errorf("unable to parse address %s "+
×
5481
                                "of type %d: %w", addr.address, addr.addrType,
×
5482
                                err)
×
5483
                }
×
5484
                if netAddr != nil {
×
5485
                        result = append(result, netAddr)
×
5486
                }
×
5487
        }
5488

5489
        // If we have no valid addresses, return nil instead of empty slice.
5490
        if len(result) == 0 {
×
5491
                return nil, nil
×
5492
        }
×
5493

5494
        return result, nil
×
5495
}
5496

5497
// parseAddress parses the given address string based on the address type
5498
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
5499
// and opaque addresses.
5500
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
5501
        switch addrType {
×
5502
        case addressTypeIPv4:
×
5503
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
5504
                if err != nil {
×
5505
                        return nil, err
×
5506
                }
×
5507

5508
                tcp.IP = tcp.IP.To4()
×
5509

×
5510
                return tcp, nil
×
5511

5512
        case addressTypeIPv6:
×
5513
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
5514
                if err != nil {
×
5515
                        return nil, err
×
5516
                }
×
5517

5518
                return tcp, nil
×
5519

5520
        case addressTypeTorV3, addressTypeTorV2:
×
5521
                service, portStr, err := net.SplitHostPort(address)
×
5522
                if err != nil {
×
5523
                        return nil, fmt.Errorf("unable to split tor "+
×
5524
                                "address: %v", address)
×
5525
                }
×
5526

5527
                port, err := strconv.Atoi(portStr)
×
5528
                if err != nil {
×
5529
                        return nil, err
×
5530
                }
×
5531

5532
                return &tor.OnionAddr{
×
5533
                        OnionService: service,
×
5534
                        Port:         port,
×
5535
                }, nil
×
5536

5537
        case addressTypeDNS:
×
5538
                hostname, portStr, err := net.SplitHostPort(address)
×
5539
                if err != nil {
×
5540
                        return nil, fmt.Errorf("unable to split DNS "+
×
5541
                                "address: %v", address)
×
5542
                }
×
5543

5544
                port, err := strconv.Atoi(portStr)
×
5545
                if err != nil {
×
5546
                        return nil, err
×
5547
                }
×
5548

5549
                return &lnwire.DNSAddress{
×
5550
                        Hostname: hostname,
×
5551
                        Port:     uint16(port),
×
5552
                }, nil
×
5553

5554
        case addressTypeOpaque:
×
5555
                opaque, err := hex.DecodeString(address)
×
5556
                if err != nil {
×
5557
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5558
                                "address: %v", address)
×
5559
                }
×
5560

5561
                return &lnwire.OpaqueAddrs{
×
5562
                        Payload: opaque,
×
5563
                }, nil
×
5564

5565
        default:
×
5566
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5567
        }
5568
}
5569

5570
// batchNodeData holds all the related data for a batch of nodes.
5571
type batchNodeData struct {
5572
        // features is a map from a DB node ID to the feature bits for that
5573
        // node.
5574
        features map[int64][]int
5575

5576
        // addresses is a map from a DB node ID to the node's addresses.
5577
        addresses map[int64][]nodeAddress
5578

5579
        // extraFields is a map from a DB node ID to the extra signed fields
5580
        // for that node.
5581
        extraFields map[int64]map[uint64][]byte
5582
}
5583

5584
// nodeAddress holds the address type, position and address string for a
5585
// node. This is used to batch the fetching of node addresses.
5586
type nodeAddress struct {
5587
        addrType dbAddressType
5588
        position int32
5589
        address  string
5590
}
5591

5592
// batchLoadNodeData loads all related data for a batch of node IDs using the
5593
// provided SQLQueries interface. It returns a batchNodeData instance containing
5594
// the node features, addresses and extra signed fields.
5595
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
5596
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
5597

×
5598
        // Batch load the node features.
×
5599
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
5600
        if err != nil {
×
5601
                return nil, fmt.Errorf("unable to batch load node "+
×
5602
                        "features: %w", err)
×
5603
        }
×
5604

5605
        // Batch load the node addresses.
5606
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
5607
        if err != nil {
×
5608
                return nil, fmt.Errorf("unable to batch load node "+
×
5609
                        "addresses: %w", err)
×
5610
        }
×
5611

5612
        // Batch load the node extra signed fields.
5613
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
5614
        if err != nil {
×
5615
                return nil, fmt.Errorf("unable to batch load node extra "+
×
5616
                        "signed fields: %w", err)
×
5617
        }
×
5618

5619
        return &batchNodeData{
×
5620
                features:    features,
×
5621
                addresses:   addrs,
×
5622
                extraFields: extraTypes,
×
5623
        }, nil
×
5624
}
5625

5626
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5627
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5628
func batchLoadNodeFeaturesHelper(ctx context.Context,
5629
        cfg *sqldb.QueryConfig, db SQLQueries,
5630
        nodeIDs []int64) (map[int64][]int, error) {
×
5631

×
5632
        features := make(map[int64][]int)
×
5633

×
5634
        return features, sqldb.ExecuteBatchQuery(
×
5635
                ctx, cfg, nodeIDs,
×
5636
                func(id int64) int64 {
×
5637
                        return id
×
5638
                },
×
5639
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5640
                        error) {
×
5641

×
5642
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5643
                },
×
5644
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5645
                        features[feature.NodeID] = append(
×
5646
                                features[feature.NodeID],
×
5647
                                int(feature.FeatureBit),
×
5648
                        )
×
5649

×
5650
                        return nil
×
5651
                },
×
5652
        )
5653
}
5654

5655
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5656
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5657
// node ID to a slice of nodeAddress structs.
5658
func batchLoadNodeAddressesHelper(ctx context.Context,
5659
        cfg *sqldb.QueryConfig, db SQLQueries,
5660
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5661

×
5662
        addrs := make(map[int64][]nodeAddress)
×
5663

×
5664
        return addrs, sqldb.ExecuteBatchQuery(
×
5665
                ctx, cfg, nodeIDs,
×
5666
                func(id int64) int64 {
×
5667
                        return id
×
5668
                },
×
5669
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5670
                        error) {
×
5671

×
5672
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5673
                },
×
5674
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5675
                        addrs[addr.NodeID] = append(
×
5676
                                addrs[addr.NodeID], nodeAddress{
×
5677
                                        addrType: dbAddressType(addr.Type),
×
5678
                                        position: addr.Position,
×
5679
                                        address:  addr.Address,
×
5680
                                },
×
5681
                        )
×
5682

×
5683
                        return nil
×
5684
                },
×
5685
        )
5686
}
5687

5688
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5689
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5690
// query.
5691
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5692
        cfg *sqldb.QueryConfig, db SQLQueries,
5693
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5694

×
5695
        extraFields := make(map[int64]map[uint64][]byte)
×
5696

×
5697
        callback := func(ctx context.Context,
×
5698
                field sqlc.GraphNodeExtraType) error {
×
5699

×
5700
                if extraFields[field.NodeID] == nil {
×
5701
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5702
                }
×
5703
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5704

×
5705
                return nil
×
5706
        }
5707

5708
        return extraFields, sqldb.ExecuteBatchQuery(
×
5709
                ctx, cfg, nodeIDs,
×
5710
                func(id int64) int64 {
×
5711
                        return id
×
5712
                },
×
5713
                func(ctx context.Context, ids []int64) (
5714
                        []sqlc.GraphNodeExtraType, error) {
×
5715

×
5716
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5717
                },
×
5718
                callback,
5719
        )
5720
}
5721

5722
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5723
// from the provided sqlc.GraphChannelPolicy records and the
5724
// provided batchChannelData.
5725
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5726
        channelID uint64, node1, node2 route.Vertex,
5727
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5728
        *models.ChannelEdgePolicy, error) {
×
5729

×
5730
        pol1, err := buildChanPolicyWithBatchData(
×
NEW
5731
                true, dbPol1, channelID, node2, batchData,
×
5732
        )
×
5733
        if err != nil {
×
5734
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5735
        }
×
5736

5737
        pol2, err := buildChanPolicyWithBatchData(
×
NEW
5738
                false, dbPol2, channelID, node1, batchData,
×
5739
        )
×
5740
        if err != nil {
×
5741
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5742
        }
×
5743

5744
        return pol1, pol2, nil
×
5745
}
5746

5747
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5748
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5749
func buildChanPolicyWithBatchData(isNode1 bool,
5750
        dbPol *sqlc.GraphChannelPolicy, channelID uint64,
5751
        toNode route.Vertex, batchData *batchChannelData) (
NEW
5752
        *models.ChannelEdgePolicy, error) {
×
5753

×
5754
        if dbPol == nil {
×
5755
                return nil, nil
×
5756
        }
×
5757

5758
        var dbPol1Extras map[uint64][]byte
×
5759
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5760
                dbPol1Extras = extras
×
5761
        } else {
×
5762
                dbPol1Extras = make(map[uint64][]byte)
×
5763
        }
×
5764

NEW
5765
        return buildChanPolicy(isNode1, *dbPol, channelID, dbPol1Extras, toNode)
×
5766
}
5767

5768
// batchChannelData holds all the related data for a batch of channels.
5769
type batchChannelData struct {
5770
        // chanFeatures is a map from DB channel ID to a slice of feature bits.
5771
        chanfeatures map[int64][]int
5772

5773
        // chanExtras is a map from DB channel ID to a map of TLV type to
5774
        // extra signed field bytes.
5775
        chanExtraTypes map[int64]map[uint64][]byte
5776

5777
        // policyExtras is a map from DB channel policy ID to a map of TLV type
5778
        // to extra signed field bytes.
5779
        policyExtras map[int64]map[uint64][]byte
5780
}
5781

5782
// batchLoadChannelData loads all related data for batches of channels and
5783
// policies.
5784
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5785
        db SQLQueries, channelIDs []int64,
5786
        policyIDs []int64) (*batchChannelData, error) {
×
5787

×
5788
        batchData := &batchChannelData{
×
5789
                chanfeatures:   make(map[int64][]int),
×
5790
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5791
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5792
        }
×
5793

×
5794
        // Batch load channel features and extras
×
5795
        var err error
×
5796
        if len(channelIDs) > 0 {
×
5797
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5798
                        ctx, cfg, db, channelIDs,
×
5799
                )
×
5800
                if err != nil {
×
5801
                        return nil, fmt.Errorf("unable to batch load "+
×
5802
                                "channel features: %w", err)
×
5803
                }
×
5804

5805
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5806
                        ctx, cfg, db, channelIDs,
×
5807
                )
×
5808
                if err != nil {
×
5809
                        return nil, fmt.Errorf("unable to batch load "+
×
5810
                                "channel extras: %w", err)
×
5811
                }
×
5812
        }
5813

5814
        if len(policyIDs) > 0 {
×
5815
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5816
                        ctx, cfg, db, policyIDs,
×
5817
                )
×
5818
                if err != nil {
×
5819
                        return nil, fmt.Errorf("unable to batch load "+
×
5820
                                "policy extras: %w", err)
×
5821
                }
×
5822
                batchData.policyExtras = policyExtras
×
5823
        }
5824

5825
        return batchData, nil
×
5826
}
5827

5828
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5829
// channel IDs using ExecuteBatchQuery wrapper around the
5830
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5831
// slice of feature bits.
5832
func batchLoadChannelFeaturesHelper(ctx context.Context,
5833
        cfg *sqldb.QueryConfig, db SQLQueries,
5834
        channelIDs []int64) (map[int64][]int, error) {
×
5835

×
5836
        features := make(map[int64][]int)
×
5837

×
5838
        return features, sqldb.ExecuteBatchQuery(
×
5839
                ctx, cfg, channelIDs,
×
5840
                func(id int64) int64 {
×
5841
                        return id
×
5842
                },
×
5843
                func(ctx context.Context,
5844
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5845

×
5846
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5847
                },
×
5848
                func(ctx context.Context,
5849
                        feature sqlc.GraphChannelFeature) error {
×
5850

×
5851
                        features[feature.ChannelID] = append(
×
5852
                                features[feature.ChannelID],
×
5853
                                int(feature.FeatureBit),
×
5854
                        )
×
5855

×
5856
                        return nil
×
5857
                },
×
5858
        )
5859
}
5860

5861
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5862
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5863
// query. It returns a map from DB channel ID to a map of TLV type to extra
5864
// signed field bytes.
5865
func batchLoadChannelExtrasHelper(ctx context.Context,
5866
        cfg *sqldb.QueryConfig, db SQLQueries,
5867
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5868

×
5869
        extras := make(map[int64]map[uint64][]byte)
×
5870

×
5871
        cb := func(ctx context.Context,
×
5872
                extra sqlc.GraphChannelExtraType) error {
×
5873

×
5874
                if extras[extra.ChannelID] == nil {
×
5875
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5876
                }
×
5877
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5878

×
5879
                return nil
×
5880
        }
5881

5882
        return extras, sqldb.ExecuteBatchQuery(
×
5883
                ctx, cfg, channelIDs,
×
5884
                func(id int64) int64 {
×
5885
                        return id
×
5886
                },
×
5887
                func(ctx context.Context,
5888
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5889

×
5890
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5891
                }, cb,
×
5892
        )
5893
}
5894

5895
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5896
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5897
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5898
// a map of TLV type to extra signed field bytes.
5899
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5900
        cfg *sqldb.QueryConfig, db SQLQueries,
5901
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5902

×
5903
        extras := make(map[int64]map[uint64][]byte)
×
5904

×
5905
        return extras, sqldb.ExecuteBatchQuery(
×
5906
                ctx, cfg, policyIDs,
×
5907
                func(id int64) int64 {
×
5908
                        return id
×
5909
                },
×
5910
                func(ctx context.Context, ids []int64) (
5911
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5912

×
5913
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5914
                },
×
5915
                func(ctx context.Context,
5916
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5917

×
5918
                        if extras[row.PolicyID] == nil {
×
5919
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5920
                        }
×
5921
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5922

×
5923
                        return nil
×
5924
                },
5925
        )
5926
}
5927

5928
// forEachNodePaginated executes a paginated query to process each node in the
5929
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5930
// and applies the provided processNode function to each node.
5931
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5932
        db SQLQueries, protocol lnwire.GossipVersion,
5933
        processNode func(context.Context, int64,
5934
                *models.Node) error) error {
×
5935

×
5936
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5937
                limit int32) ([]sqlc.GraphNode, error) {
×
5938

×
5939
                return db.ListNodesPaginated(
×
5940
                        ctx, sqlc.ListNodesPaginatedParams{
×
5941
                                Version: int16(protocol),
×
5942
                                ID:      lastID,
×
5943
                                Limit:   limit,
×
5944
                        },
×
5945
                )
×
5946
        }
×
5947

5948
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5949
                return node.ID
×
5950
        }
×
5951

5952
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5953
                return node.ID, nil
×
5954
        }
×
5955

5956
        batchQueryFunc := func(ctx context.Context,
×
5957
                nodeIDs []int64) (*batchNodeData, error) {
×
5958

×
5959
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5960
        }
×
5961

5962
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5963
                batchData *batchNodeData) error {
×
5964

×
5965
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5966
                if err != nil {
×
5967
                        return fmt.Errorf("unable to build "+
×
5968
                                "node(id=%d): %w", dbNode.ID, err)
×
5969
                }
×
5970

5971
                return processNode(ctx, dbNode.ID, node)
×
5972
        }
5973

5974
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5975
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5976
                collectFunc, batchQueryFunc, processItem,
×
5977
        )
×
5978
}
5979

5980
// forEachChannelWithPolicies executes a paginated query to process each channel
5981
// with policies in the graph.
5982
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5983
        cfg *SQLStoreConfig, v lnwire.GossipVersion,
5984
        processChannel func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
5985
                *models.ChannelEdgePolicy) error) error {
×
5986

×
5987
        type channelBatchIDs struct {
×
5988
                channelID int64
×
5989
                policyIDs []int64
×
5990
        }
×
5991

×
5992
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5993
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5994
                error) {
×
5995

×
5996
                return db.ListChannelsWithPoliciesPaginated(
×
5997
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
NEW
5998
                                Version: int16(v),
×
5999
                                ID:      lastID,
×
6000
                                Limit:   limit,
×
6001
                        },
×
6002
                )
×
6003
        }
×
6004

6005
        extractPageCursor := func(
×
6006
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
6007

×
6008
                return row.GraphChannel.ID
×
6009
        }
×
6010

6011
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
6012
                channelBatchIDs, error) {
×
6013

×
6014
                ids := channelBatchIDs{
×
6015
                        channelID: row.GraphChannel.ID,
×
6016
                }
×
6017

×
6018
                // Extract policy IDs from the row.
×
6019
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6020
                if err != nil {
×
6021
                        return ids, err
×
6022
                }
×
6023

6024
                if dbPol1 != nil {
×
6025
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
6026
                }
×
6027
                if dbPol2 != nil {
×
6028
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
6029
                }
×
6030

6031
                return ids, nil
×
6032
        }
6033

6034
        batchDataFunc := func(ctx context.Context,
×
6035
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
6036

×
6037
                // Separate channel IDs from policy IDs.
×
6038
                var (
×
6039
                        channelIDs = make([]int64, len(allIDs))
×
6040
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
6041
                )
×
6042

×
6043
                for i, ids := range allIDs {
×
6044
                        channelIDs[i] = ids.channelID
×
6045
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
6046
                }
×
6047

6048
                return batchLoadChannelData(
×
6049
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
6050
                )
×
6051
        }
6052

6053
        processItem := func(ctx context.Context,
×
6054
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
6055
                batchData *batchChannelData) error {
×
6056

×
6057
                node1, node2, err := buildNodeVertices(
×
6058
                        row.Node1Pubkey, row.Node2Pubkey,
×
6059
                )
×
6060
                if err != nil {
×
6061
                        return err
×
6062
                }
×
6063

6064
                edge, err := buildEdgeInfoWithBatchData(
×
6065
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
6066
                        batchData,
×
6067
                )
×
6068
                if err != nil {
×
6069
                        return fmt.Errorf("unable to build channel info: %w",
×
6070
                                err)
×
6071
                }
×
6072

6073
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6074
                if err != nil {
×
6075
                        return err
×
6076
                }
×
6077

6078
                p1, p2, err := buildChanPoliciesWithBatchData(
×
6079
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
6080
                )
×
6081
                if err != nil {
×
6082
                        return err
×
6083
                }
×
6084

6085
                return processChannel(edge, p1, p2)
×
6086
        }
6087

6088
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
6089
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
6090
                collectFunc, batchDataFunc, processItem,
×
6091
        )
×
6092
}
6093

6094
// buildDirectedChannel builds a DirectedChannel instance from the provided
6095
// data.
6096
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
6097
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
6098
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
6099
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
6100

×
6101
        node1, node2, err := buildNodeVertices(
×
6102
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
6103
        )
×
6104
        if err != nil {
×
6105
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
6106
        }
×
6107

6108
        edge, err := buildEdgeInfoWithBatchData(
×
6109
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
6110
        )
×
6111
        if err != nil {
×
6112
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
6113
        }
×
6114

6115
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
6116
        if err != nil {
×
6117
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
6118
                        err)
×
6119
        }
×
6120

6121
        p1, p2, err := buildChanPoliciesWithBatchData(
×
6122
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
6123
                channelBatchData,
×
6124
        )
×
6125
        if err != nil {
×
6126
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
6127
                        err)
×
6128
        }
×
6129

6130
        // Determine outgoing and incoming policy for this specific node.
6131
        p1ToNode := channelRow.GraphChannel.NodeID2
×
6132
        p2ToNode := channelRow.GraphChannel.NodeID1
×
6133
        outPolicy, inPolicy := p1, p2
×
6134
        if (p1 != nil && p1ToNode == nodeID) ||
×
6135
                (p2 != nil && p2ToNode != nodeID) {
×
6136

×
6137
                outPolicy, inPolicy = p2, p1
×
6138
        }
×
6139

6140
        // Build cached policy.
6141
        var cachedInPolicy *models.CachedEdgePolicy
×
6142
        if inPolicy != nil {
×
6143
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
6144
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
6145
                cachedInPolicy.ToNodeFeatures = features
×
6146
        }
×
6147

6148
        // Extract inbound fee.
6149
        var inboundFee lnwire.Fee
×
6150
        if outPolicy != nil {
×
6151
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
6152
                        inboundFee = fee
×
6153
                })
×
6154
        }
6155

6156
        // Build directed channel.
6157
        directedChannel := &DirectedChannel{
×
6158
                ChannelID:    edge.ChannelID,
×
6159
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
6160
                OtherNode:    edge.NodeKey2Bytes,
×
6161
                Capacity:     edge.Capacity,
×
6162
                OutPolicySet: outPolicy != nil,
×
6163
                InPolicy:     cachedInPolicy,
×
6164
                InboundFee:   inboundFee,
×
6165
        }
×
6166

×
6167
        if nodePub == edge.NodeKey2Bytes {
×
6168
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
6169
        }
×
6170

6171
        return directedChannel, nil
×
6172
}
6173

6174
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
6175
// provided rows. It uses batch loading for channels, policies, and nodes.
6176
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
6177
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
6178

×
6179
        var (
×
6180
                channelIDs = make([]int64, len(rows))
×
6181
                policyIDs  = make([]int64, 0, len(rows)*2)
×
6182
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
6183

×
6184
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
6185
                nodeIDSet = make(map[int64]bool)
×
6186

×
6187
                // edges will hold the final channel edges built from the rows.
×
6188
                edges = make([]ChannelEdge, 0, len(rows))
×
6189
        )
×
6190

×
6191
        // Collect all IDs needed for batch loading.
×
6192
        for i, row := range rows {
×
6193
                channelIDs[i] = row.Channel().ID
×
6194

×
6195
                // Collect policy IDs
×
6196
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6197
                if err != nil {
×
6198
                        return nil, fmt.Errorf("unable to extract channel "+
×
6199
                                "policies: %w", err)
×
6200
                }
×
6201
                if dbPol1 != nil {
×
6202
                        policyIDs = append(policyIDs, dbPol1.ID)
×
6203
                }
×
6204
                if dbPol2 != nil {
×
6205
                        policyIDs = append(policyIDs, dbPol2.ID)
×
6206
                }
×
6207

6208
                var (
×
6209
                        node1ID = row.Node1().ID
×
6210
                        node2ID = row.Node2().ID
×
6211
                )
×
6212

×
6213
                // Collect unique node IDs.
×
6214
                if !nodeIDSet[node1ID] {
×
6215
                        nodeIDs = append(nodeIDs, node1ID)
×
6216
                        nodeIDSet[node1ID] = true
×
6217
                }
×
6218

6219
                if !nodeIDSet[node2ID] {
×
6220
                        nodeIDs = append(nodeIDs, node2ID)
×
6221
                        nodeIDSet[node2ID] = true
×
6222
                }
×
6223
        }
6224

6225
        // Batch the data for all the channels and policies.
6226
        channelBatchData, err := batchLoadChannelData(
×
6227
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
6228
        )
×
6229
        if err != nil {
×
6230
                return nil, fmt.Errorf("unable to batch load channel and "+
×
6231
                        "policy data: %w", err)
×
6232
        }
×
6233

6234
        // Batch the data for all the nodes.
6235
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
6236
        if err != nil {
×
6237
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
6238
                        err)
×
6239
        }
×
6240

6241
        // Build all channel edges using batch data.
6242
        for _, row := range rows {
×
6243
                // Build nodes using batch data.
×
6244
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
6245
                if err != nil {
×
6246
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
6247
                }
×
6248

6249
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
6250
                if err != nil {
×
6251
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
6252
                }
×
6253

6254
                // Build channel info using batch data.
6255
                channel, err := buildEdgeInfoWithBatchData(
×
6256
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
6257
                        node2.PubKeyBytes, channelBatchData,
×
6258
                )
×
6259
                if err != nil {
×
6260
                        return nil, fmt.Errorf("unable to build channel "+
×
6261
                                "info: %w", err)
×
6262
                }
×
6263

6264
                // Extract and build policies using batch data.
6265
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
6266
                if err != nil {
×
6267
                        return nil, fmt.Errorf("unable to extract channel "+
×
6268
                                "policies: %w", err)
×
6269
                }
×
6270

6271
                p1, p2, err := buildChanPoliciesWithBatchData(
×
6272
                        dbPol1, dbPol2, channel.ChannelID,
×
6273
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
6274
                )
×
6275
                if err != nil {
×
6276
                        return nil, fmt.Errorf("unable to build channel "+
×
6277
                                "policies: %w", err)
×
6278
                }
×
6279

6280
                edges = append(edges, ChannelEdge{
×
6281
                        Info:    channel,
×
6282
                        Policy1: p1,
×
6283
                        Policy2: p2,
×
6284
                        Node1:   node1,
×
6285
                        Node2:   node2,
×
6286
                })
×
6287
        }
6288

6289
        return edges, nil
×
6290
}
6291

6292
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
6293
// instances from the provided rows using batch loading for channel data.
6294
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
6295
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
6296
        []*models.ChannelEdgeInfo, []int64, error) {
×
6297

×
6298
        if len(rows) == 0 {
×
6299
                return nil, nil, nil
×
6300
        }
×
6301

6302
        // Collect all the channel IDs needed for batch loading.
6303
        channelIDs := make([]int64, len(rows))
×
6304
        for i, row := range rows {
×
6305
                channelIDs[i] = row.Channel().ID
×
6306
        }
×
6307

6308
        // Batch load the channel data.
6309
        channelBatchData, err := batchLoadChannelData(
×
6310
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
6311
        )
×
6312
        if err != nil {
×
6313
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
6314
                        "data: %w", err)
×
6315
        }
×
6316

6317
        // Build all channel edges using batch data.
6318
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
6319
        for _, row := range rows {
×
6320
                node1, node2, err := buildNodeVertices(
×
6321
                        row.Node1Pub(), row.Node2Pub(),
×
6322
                )
×
6323
                if err != nil {
×
6324
                        return nil, nil, err
×
6325
                }
×
6326

6327
                // Build channel info using batch data
6328
                info, err := buildEdgeInfoWithBatchData(
×
6329
                        cfg.ChainHash, row.Channel(), node1, node2,
×
6330
                        channelBatchData,
×
6331
                )
×
6332
                if err != nil {
×
6333
                        return nil, nil, err
×
6334
                }
×
6335

6336
                edges = append(edges, info)
×
6337
        }
6338

6339
        return edges, channelIDs, nil
×
6340
}
6341

6342
// handleZombieMarking is a helper function that handles the logic of
6343
// marking a channel as a zombie in the database. It takes into account whether
6344
// we are in strict zombie pruning mode, and adjusts the node public keys
6345
// accordingly based on the last update timestamps of the channel policies.
6346
func handleZombieMarking(ctx context.Context, db SQLQueries,
6347
        v lnwire.GossipVersion,
6348
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
6349
        strictZombiePruning bool, scid uint64) error {
×
6350

×
6351
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
6352

×
6353
        if strictZombiePruning {
×
NEW
6354
                // TODO(elle): update for V2 last update times.
×
NEW
6355
                if v != lnwire.GossipVersion1 {
×
NEW
6356
                        return fmt.Errorf("strict zombie pruning only "+
×
NEW
6357
                                "supported for gossip v1, got %v", v)
×
NEW
6358
                }
×
6359

6360
                var e1UpdateTime, e2UpdateTime *time.Time
×
6361
                if row.Policy1LastUpdate.Valid {
×
6362
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
6363
                        e1UpdateTime = &e1Time
×
6364
                }
×
6365
                if row.Policy2LastUpdate.Valid {
×
6366
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
6367
                        e2UpdateTime = &e2Time
×
6368
                }
×
6369

6370
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
6371
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
6372
                        e2UpdateTime,
×
6373
                )
×
6374
        }
6375

6376
        return db.UpsertZombieChannel(
×
6377
                ctx, sqlc.UpsertZombieChannelParams{
×
NEW
6378
                        Version:  int16(v),
×
6379
                        Scid:     channelIDToBytes(scid),
×
6380
                        NodeKey1: nodeKey1[:],
×
6381
                        NodeKey2: nodeKey2[:],
×
6382
                },
×
6383
        )
×
6384
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc