lightningnetwork / lnd / 20239727253

15 Dec 2025 04:31PM UTC coverage: 53.405% (-11.8%) from 65.173%

Pull Request #10363: graphdb: add caching for isPublicNode query
Merge f9078e552 into 06cc0f3aa

4 of 61 new or added lines in 4 files covered (6.56%)

24136 existing lines in 286 files now uncovered.

110522 of 206950 relevant lines covered (53.41%)

21143.83 hits per line

Source File: /graph/db/sql_store.go (0.0% covered)

1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightninglabs/neutrino/cache"
25
        "github.com/lightninglabs/neutrino/cache/lru"
26
        "github.com/lightningnetwork/lnd/aliasmgr"
27
        "github.com/lightningnetwork/lnd/batch"
28
        "github.com/lightningnetwork/lnd/fn/v2"
29
        "github.com/lightningnetwork/lnd/graph/db/models"
30
        "github.com/lightningnetwork/lnd/lnwire"
31
        "github.com/lightningnetwork/lnd/routing/route"
32
        "github.com/lightningnetwork/lnd/sqldb"
33
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
34
        "github.com/lightningnetwork/lnd/tlv"
35
        "github.com/lightningnetwork/lnd/tor"
36
)
37

38
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
39
// execute queries against the SQL graph tables.
40
//
41
//nolint:ll,interfacebloat
42
type SQLQueries interface {
43
        /*
44
                Node queries.
45
        */
46
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
47
        UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
48
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
49
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
50
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
51
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
52
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
53
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
54
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
55
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
56
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
57
        DeleteNode(ctx context.Context, id int64) error
58

59
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
60
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
61
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
62
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
63

64
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
65
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
66
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
67
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
68

69
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
70
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
71
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
72
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
73
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
74

75
        /*
76
                Source node queries.
77
        */
78
        AddSourceNode(ctx context.Context, nodeID int64) error
79
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
80

81
        /*
82
                Channel queries.
83
        */
84
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
85
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
86
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
87
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
88
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
89
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
90
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
91
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
92
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
93
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
94
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
95
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
96
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
97
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
98
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
99
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
100
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
101
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
102
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
103
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
104
        DeleteChannels(ctx context.Context, ids []int64) error
105

106
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
107
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
108
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
109
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
110

111
        /*
112
                Channel Policy table queries.
113
        */
114
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
115
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
116
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
117

118
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
119
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
120
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
121

122
        /*
123
                Zombie index queries.
124
        */
125
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
126
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
127
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
128
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
129
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
130
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
131

132
        /*
133
                Prune log table queries.
134
        */
135
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
136
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
137
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
147

148
        /*
149
                Migration specific queries.
150

151
                NOTE: these should not be used in code other than migrations.
152
                Once sqldbv2 is in place, these can be removed from this struct
153
                as then migrations will have their own dedicated queries
154
                structs.
155
        */
156
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
157
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
158
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
159
}
160

161
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
162
// database operations.
163
type BatchedSQLQueries interface {
164
        SQLQueries
165
        sqldb.BatchedTx[SQLQueries]
166
}
167

168
// SQLStore is an implementation of the V1Store interface that uses a SQL
169
// database as the backend.
170
type SQLStore struct {
171
        cfg *SQLStoreConfig
172
        db  BatchedSQLQueries
173

174
        // cacheMu guards all caches (rejectCache and chanCache). If
175
        // this mutex is acquired at the same time as the DB mutex, then
176
        // the cacheMu MUST be acquired first to prevent deadlock.
177
        cacheMu     sync.RWMutex
178
        rejectCache *rejectCache
179
        chanCache   *channelCache
180

181
        publicNodeCache *lru.Cache[[33]byte, *cachedPublicNode]
182

183
        chanScheduler batch.Scheduler[SQLQueries]
184
        nodeScheduler batch.Scheduler[SQLQueries]
185

186
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
187
        srcNodeMu sync.Mutex
188
}
189

190
// cachedPublicNode is a simple wrapper for a boolean value that can be
191
// stored in an LRU cache. The LRU cache requires a Size() method.
192
type cachedPublicNode struct {
193
        isPublic bool
194
}
195

196
// Size returns the size of the cache entry. We return 1 as we just want to
197
// limit the number of entries rather than their actual memory size.
NEW
198
func (c *cachedPublicNode) Size() (uint64, error) {
×
NEW
199
        return 1, nil
×
NEW
200
}
×
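// The following is an editorial sketch, not part of the PR: it
// illustrates how the new publicNodeCache could be consulted before
// falling back to the IsPublicV1Node query. The helper name
// isPublicNodeCached is hypothetical, and the lru.Cache Get/Put
// signatures are assumed from the neutrino cache package.
//
//	func (s *SQLStore) isPublicNodeCached(ctx context.Context,
//	        db SQLQueries, pubKey [33]byte) (bool, error) {
//
//	        // Serve from the cache when possible.
//	        if entry, err := s.publicNodeCache.Get(pubKey); err == nil {
//	                return entry.isPublic, nil
//	        }
//
//	        // Cache miss: fall back to the DB query.
//	        isPublic, err := db.IsPublicV1Node(ctx, pubKey[:])
//	        if err != nil {
//	                return false, err
//	        }
//
//	        // Populate the cache for subsequent lookups.
//	        _, _ = s.publicNodeCache.Put(
//	                pubKey, &cachedPublicNode{isPublic: isPublic},
//	        )
//
//	        return isPublic, nil
//	}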
201

202
// A compile-time assertion to ensure that SQLStore implements the V1Store
203
// interface.
204
var _ V1Store = (*SQLStore)(nil)
205

206
// SQLStoreConfig holds the configuration for the SQLStore.
207
type SQLStoreConfig struct {
208
        // ChainHash is the genesis hash for the chain that all the gossip
209
        // messages in this store are aimed at.
210
        ChainHash chainhash.Hash
211

212
        // QueryConfig holds configuration values for SQL queries.
213
        QueryCfg *sqldb.QueryConfig
214
}
215

216
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
217
// storage backend.
218
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
219
        options ...StoreOptionModifier) (*SQLStore, error) {
×
220

×
221
        opts := DefaultOptions()
×
222
        for _, o := range options {
×
223
                o(opts)
×
224
        }
×
225

226
        if opts.NoMigration {
×
227
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
228
                        "supported for SQL stores")
×
229
        }
×
230

231
        s := &SQLStore{
×
232
                cfg:         cfg,
×
233
                db:          db,
×
234
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
235
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
NEW
236
                publicNodeCache: lru.NewCache[[33]byte, *cachedPublicNode](
×
NEW
237
                        uint64(opts.PublicNodeCacheSize),
×
NEW
238
                ),
×
NEW
239
                srcNodes: make(map[lnwire.GossipVersion]*srcNodeInfo),
×
240
        }
×
241

×
242
        s.chanScheduler = batch.NewTimeScheduler(
×
243
                db, &s.cacheMu, opts.BatchCommitInterval,
×
244
        )
×
245
        s.nodeScheduler = batch.NewTimeScheduler(
×
246
                db, nil, opts.BatchCommitInterval,
×
247
        )
×
248

×
249
        return s, nil
×
250
}
251

252
// AddNode adds a vertex/node to the graph database. If the node is not
253
// in the database from before, this will add a new, unconnected one to the
254
// graph. If it is present from before, this will update that node's
255
// information.
256
//
257
// NOTE: part of the V1Store interface.
258
func (s *SQLStore) AddNode(ctx context.Context,
259
        node *models.Node, opts ...batch.SchedulerOption) error {
×
260

×
261
        r := &batch.Request[SQLQueries]{
×
262
                Opts: batch.NewSchedulerOptions(opts...),
×
263
                Do: func(queries SQLQueries) error {
×
264
                        _, err := upsertNode(ctx, queries, node)
×
265

×
266
                        // It is possible that two of the same node
×
267
                        // announcements are both being processed in the same
×
268
                        // batch. This may cause the UpsertNode conflict to
×
269
                        // be hit since we require at the db layer that the
×
270
                        // new last_update is greater than the existing
×
271
                        // last_update. We need to gracefully handle this here.
×
272
                        if errors.Is(err, sql.ErrNoRows) {
×
273
                                return nil
×
274
                        }
×
275

276
                        return err
×
277
                },
278
        }
279

280
        return s.nodeScheduler.Execute(ctx, r)
×
281
}
282

283
// FetchNode attempts to look up a target node by its identity public
284
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
285
// returned.
286
//
287
// NOTE: part of the V1Store interface.
288
func (s *SQLStore) FetchNode(ctx context.Context,
289
        pubKey route.Vertex) (*models.Node, error) {
×
290

×
291
        var node *models.Node
×
292
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
293
                var err error
×
294
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
295

×
296
                return err
×
297
        }, sqldb.NoOpReset)
×
298
        if err != nil {
×
299
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
300
        }
×
301

302
        return node, nil
×
303
}
304

305
// HasNode determines if the graph has a vertex identified by the
306
// target node identity public key. If the node exists in the database, a
307
// timestamp of when the data for the node was last updated is returned along
308
// with a true boolean. Otherwise, an empty time.Time is returned with a false
309
// boolean.
310
//
311
// NOTE: part of the V1Store interface.
312
func (s *SQLStore) HasNode(ctx context.Context,
313
        pubKey [33]byte) (time.Time, bool, error) {
×
314

×
315
        var (
×
316
                exists     bool
×
317
                lastUpdate time.Time
×
318
        )
×
319
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
320
                dbNode, err := db.GetNodeByPubKey(
×
321
                        ctx, sqlc.GetNodeByPubKeyParams{
×
322
                                Version: int16(lnwire.GossipVersion1),
×
323
                                PubKey:  pubKey[:],
×
324
                        },
×
325
                )
×
326
                if errors.Is(err, sql.ErrNoRows) {
×
327
                        return nil
×
328
                } else if err != nil {
×
329
                        return fmt.Errorf("unable to fetch node: %w", err)
×
330
                }
×
331

332
                exists = true
×
333

×
334
                if dbNode.LastUpdate.Valid {
×
335
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
336
                }
×
337

338
                return nil
×
339
        }, sqldb.NoOpReset)
340
        if err != nil {
×
341
                return time.Time{}, false,
×
342
                        fmt.Errorf("unable to fetch node: %w", err)
×
343
        }
×
344

345
        return lastUpdate, exists, nil
×
346
}
347

348
// AddrsForNode returns all known addresses for the target node public key
349
// that the graph DB is aware of. The returned boolean indicates if the
350
// given node is unknown to the graph DB or not.
351
//
352
// NOTE: part of the V1Store interface.
353
func (s *SQLStore) AddrsForNode(ctx context.Context,
354
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
355

×
356
        var (
×
357
                addresses []net.Addr
×
358
                known     bool
×
359
        )
×
360
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
361
                // First, check if the node exists and get its DB ID if it
×
362
                // does.
×
363
                dbID, err := db.GetNodeIDByPubKey(
×
364
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
365
                                Version: int16(lnwire.GossipVersion1),
×
366
                                PubKey:  nodePub.SerializeCompressed(),
×
367
                        },
×
368
                )
×
369
                if errors.Is(err, sql.ErrNoRows) {
×
370
                        return nil
×
371
                }
×
372

373
                known = true
×
374

×
375
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
376
                if err != nil {
×
377
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
378
                                err)
×
379
                }
×
380

381
                return nil
×
382
        }, sqldb.NoOpReset)
383
        if err != nil {
×
384
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
385
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
386
        }
×
387

388
        return known, addresses, nil
×
389
}
390

391
// DeleteNode starts a new database transaction to remove a vertex/node
392
// from the database according to the node's public key.
393
//
394
// NOTE: part of the V1Store interface.
395
func (s *SQLStore) DeleteNode(ctx context.Context,
396
        pubKey route.Vertex) error {
×
397

×
398
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
399
                res, err := db.DeleteNodeByPubKey(
×
400
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
401
                                Version: int16(lnwire.GossipVersion1),
×
402
                                PubKey:  pubKey[:],
×
403
                        },
×
404
                )
×
405
                if err != nil {
×
406
                        return err
×
407
                }
×
408

409
                rows, err := res.RowsAffected()
×
410
                if err != nil {
×
411
                        return err
×
412
                }
×
413

414
                if rows == 0 {
×
415
                        return ErrGraphNodeNotFound
×
416
                } else if rows > 1 {
×
417
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
418
                }
×
419

420
                return err
×
421
        }, sqldb.NoOpReset)
422
        if err != nil {
×
423
                return fmt.Errorf("unable to delete node: %w", err)
×
424
        }
×
425

NEW
426
        s.removePublicNodeCache(pubKey)
×
NEW
427

×
UNCOV
428
        return nil
×
429
}
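// Editorial note (not part of the PR): DeleteNode above evicts the
// node's entry from the new publicNodeCache via removePublicNodeCache,
// whose definition is not shown in this excerpt. A minimal sketch of
// such a helper, assuming a Delete method on the neutrino lru.Cache and
// a [33]byte key type, could look like:
//
//	func (s *SQLStore) removePublicNodeCache(pubKeys ...[33]byte) {
//	        for _, pubKey := range pubKeys {
//	                // Drop any cached is-public flag so the next
//	                // lookup re-queries the database.
//	                s.publicNodeCache.Delete(pubKey)
//	        }
//	}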
430

431
// FetchNodeFeatures returns the features of the given node. If no features are
432
// known for the node, an empty feature vector is returned.
433
//
434
// NOTE: this is part of the graphdb.NodeTraverser interface.
435
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
436
        *lnwire.FeatureVector, error) {
×
437

×
438
        ctx := context.TODO()
×
439

×
440
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
441
}
×
442

443
// DisabledChannelIDs returns the channel ids of disabled channels.
444
// A channel is disabled when two of the associated ChanelEdgePolicies
445
// have their disabled bit on.
446
//
447
// NOTE: part of the V1Store interface.
448
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
449
        var (
×
450
                ctx     = context.TODO()
×
451
                chanIDs []uint64
×
452
        )
×
453
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
454
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
455
                if err != nil {
×
456
                        return fmt.Errorf("unable to fetch disabled "+
×
457
                                "channels: %w", err)
×
458
                }
×
459

460
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
461

×
462
                return nil
×
463
        }, sqldb.NoOpReset)
464
        if err != nil {
×
465
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
466
                        err)
×
467
        }
×
468

469
        return chanIDs, nil
×
470
}
471

472
// LookupAlias attempts to return the alias as advertised by the target node.
473
//
474
// NOTE: part of the V1Store interface.
475
func (s *SQLStore) LookupAlias(ctx context.Context,
476
        pub *btcec.PublicKey) (string, error) {
×
477

×
478
        var alias string
×
479
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
480
                dbNode, err := db.GetNodeByPubKey(
×
481
                        ctx, sqlc.GetNodeByPubKeyParams{
×
482
                                Version: int16(lnwire.GossipVersion1),
×
483
                                PubKey:  pub.SerializeCompressed(),
×
484
                        },
×
485
                )
×
486
                if errors.Is(err, sql.ErrNoRows) {
×
487
                        return ErrNodeAliasNotFound
×
488
                } else if err != nil {
×
489
                        return fmt.Errorf("unable to fetch node: %w", err)
×
490
                }
×
491

492
                if !dbNode.Alias.Valid {
×
493
                        return ErrNodeAliasNotFound
×
494
                }
×
495

496
                alias = dbNode.Alias.String
×
497

×
498
                return nil
×
499
        }, sqldb.NoOpReset)
500
        if err != nil {
×
501
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
502
        }
×
503

504
        return alias, nil
×
505
}
506

507
// SourceNode returns the source node of the graph. The source node is treated
508
// as the center node within a star-graph. This method may be used to kick off
509
// a path finding algorithm in order to explore the reachability of another
510
// node based off the source node.
511
//
512
// NOTE: part of the V1Store interface.
513
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
514
        error) {
×
515

×
516
        var node *models.Node
×
517
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
518
                _, nodePub, err := s.getSourceNode(
×
519
                        ctx, db, lnwire.GossipVersion1,
×
520
                )
×
521
                if err != nil {
×
522
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
523
                                err)
×
524
                }
×
525

526
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
527

×
528
                return err
×
529
        }, sqldb.NoOpReset)
530
        if err != nil {
×
531
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
532
        }
×
533

534
        return node, nil
×
535
}
536

537
// SetSourceNode sets the source node within the graph database. The source
538
// node is to be used as the center of a star-graph within path finding
539
// algorithms.
540
//
541
// NOTE: part of the V1Store interface.
542
func (s *SQLStore) SetSourceNode(ctx context.Context,
543
        node *models.Node) error {
×
544

×
545
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
546
                // For the source node, we use a less strict upsert that allows
×
547
                // updates even when the timestamp hasn't changed. This handles
×
548
                // the race condition where multiple goroutines (e.g.,
×
549
                // setSelfNode, createNewHiddenService, RPC updates) read the
×
550
                // same old timestamp, independently increment it, and try to
×
551
                // write concurrently. We want all parameter changes to persist,
×
552
                // even if timestamps collide.
×
553
                id, err := upsertSourceNode(ctx, db, node)
×
554
                if err != nil {
×
555
                        return fmt.Errorf("unable to upsert source node: %w",
×
556
                                err)
×
557
                }
×
558

559
                // Make sure that if a source node for this version is already
560
                // set, then the ID is the same as the one we are about to set.
561
                dbSourceNodeID, _, err := s.getSourceNode(
×
562
                        ctx, db, lnwire.GossipVersion1,
×
563
                )
×
564
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
565
                        return fmt.Errorf("unable to fetch source node: %w",
×
566
                                err)
×
567
                } else if err == nil {
×
568
                        if dbSourceNodeID != id {
×
569
                                return fmt.Errorf("v1 source node already "+
×
570
                                        "set to a different node: %d vs %d",
×
571
                                        dbSourceNodeID, id)
×
572
                        }
×
573

574
                        return nil
×
575
                }
576

577
                return db.AddSourceNode(ctx, id)
×
578
        }, sqldb.NoOpReset)
579
}
580

581
// NodeUpdatesInHorizon returns all the known lightning node which have an
582
// update timestamp within the passed range. This method can be used by two
583
// nodes to quickly determine if they have the same set of up to date node
584
// announcements.
585
//
586
// NOTE: This is part of the V1Store interface.
587
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
588
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
589

×
590
        cfg := defaultIteratorConfig()
×
591
        for _, opt := range opts {
×
592
                opt(cfg)
×
593
        }
×
594

595
        return func(yield func(*models.Node, error) bool) {
×
596
                var (
×
597
                        ctx            = context.TODO()
×
598
                        lastUpdateTime sql.NullInt64
×
599
                        lastPubKey     = make([]byte, 33)
×
600
                        hasMore        = true
×
601
                )
×
602

×
603
                // Each iteration, we'll read a batch of nodes, yield
×
604
                // them, then decide if we have more or not.
×
605
                for hasMore {
×
606
                        var batch []*models.Node
×
607

×
608
                        //nolint:ll
×
609
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
610
                                //nolint:ll
×
611
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
612
                                        StartTime: sqldb.SQLInt64(
×
613
                                                startTime.Unix(),
×
614
                                        ),
×
615
                                        EndTime: sqldb.SQLInt64(
×
616
                                                endTime.Unix(),
×
617
                                        ),
×
618
                                        LastUpdate: lastUpdateTime,
×
619
                                        LastPubKey: lastPubKey,
×
620
                                        OnlyPublic: sql.NullBool{
×
621
                                                Bool:  cfg.iterPublicNodes,
×
622
                                                Valid: true,
×
623
                                        },
×
624
                                        MaxResults: sqldb.SQLInt32(
×
625
                                                cfg.nodeUpdateIterBatchSize,
×
626
                                        ),
×
627
                                }
×
628
                                rows, err := db.GetNodesByLastUpdateRange(
×
629
                                        ctx, params,
×
630
                                )
×
631
                                if err != nil {
×
632
                                        return err
×
633
                                }
×
634

635
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
636

×
637
                                err = forEachNodeInBatch(
×
638
                                        ctx, s.cfg.QueryCfg, db, rows,
×
639
                                        func(_ int64, node *models.Node) error {
×
640
                                                batch = append(batch, node)
×
641

×
642
                                                // Update pagination cursors
×
643
                                                // based on the last processed
×
644
                                                // node.
×
645
                                                lastUpdateTime = sql.NullInt64{
×
646
                                                        Int64: node.LastUpdate.
×
647
                                                                Unix(),
×
648
                                                        Valid: true,
×
649
                                                }
×
650
                                                lastPubKey = node.PubKeyBytes[:]
×
651

×
652
                                                return nil
×
653
                                        },
×
654
                                )
655
                                if err != nil {
×
656
                                        return fmt.Errorf("unable to build "+
×
657
                                                "nodes: %w", err)
×
658
                                }
×
659

660
                                return nil
×
661
                        }, func() {
×
662
                                batch = []*models.Node{}
×
663
                        })
×
664

665
                        if err != nil {
×
666
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
667
                                        "error: %v", err)
×
668

×
669
                                yield(&models.Node{}, err)
×
670

×
671
                                return
×
672
                        }
×
673

674
                        for _, node := range batch {
×
675
                                if !yield(node, nil) {
×
676
                                        return
×
677
                                }
×
678
                        }
679

680
                        // If the batch didn't yield anything, then we're done.
681
                        if len(batch) == 0 {
×
682
                                break
×
683
                        }
684
                }
685
        }
686
}
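// Usage sketch (editorial, illustrative only): the returned iter.Seq2
// can be consumed with a standard Go range-over-func loop, and breaking
// out of the loop stops the underlying pagination early. The store,
// start, end and process identifiers below are placeholders.
//
//	for node, err := range store.NodeUpdatesInHorizon(start, end) {
//	        if err != nil {
//	                return err
//	        }
//	        if err := process(node); err != nil {
//	                return err
//	        }
//	}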
687

688
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
689
// undirected edge between the two target nodes is created. The information stored
690
// denotes the static attributes of the channel, such as the channelID, the keys
691
// involved in creation of the channel, and the set of features that the channel
692
// supports. The chanPoint and chanID are used to uniquely identify the edge
693
// globally within the database.
694
//
695
// NOTE: part of the V1Store interface.
696
func (s *SQLStore) AddChannelEdge(ctx context.Context,
697
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
698

×
699
        var alreadyExists bool
×
700
        r := &batch.Request[SQLQueries]{
×
701
                Opts: batch.NewSchedulerOptions(opts...),
×
702
                Reset: func() {
×
703
                        alreadyExists = false
×
704
                },
×
705
                Do: func(tx SQLQueries) error {
×
706
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
707

×
708
                        // Make sure that the channel doesn't already exist. We
×
709
                        // do this explicitly instead of relying on catching a
×
710
                        // unique constraint error because relying on SQL to
×
711
                        // throw that error would abort the entire batch of
×
712
                        // transactions.
×
713
                        _, err := tx.GetChannelBySCID(
×
714
                                ctx, sqlc.GetChannelBySCIDParams{
×
715
                                        Scid:    chanIDB,
×
716
                                        Version: int16(lnwire.GossipVersion1),
×
717
                                },
×
718
                        )
×
719
                        if err == nil {
×
720
                                alreadyExists = true
×
721
                                return nil
×
722
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
723
                                return fmt.Errorf("unable to fetch channel: %w",
×
724
                                        err)
×
725
                        }
×
726

727
                        return insertChannel(ctx, tx, edge)
×
728
                },
729
                OnCommit: func(err error) error {
×
730
                        switch {
×
731
                        case err != nil:
×
732
                                return err
×
733
                        case alreadyExists:
×
734
                                return ErrEdgeAlreadyExist
×
735
                        default:
×
736
                                s.rejectCache.remove(edge.ChannelID)
×
737
                                s.chanCache.remove(edge.ChannelID)
×
NEW
738
                                s.removePublicNodeCache(
×
NEW
739
                                        edge.NodeKey1Bytes, edge.NodeKey2Bytes,
×
NEW
740
                                )
×
NEW
741

×
UNCOV
742
                                return nil
×
743
                        }
744
                },
745
        }
746

747
        return s.chanScheduler.Execute(ctx, r)
×
748
}
749

750
// HighestChanID returns the "highest" known channel ID in the channel graph.
751
// This represents the "newest" channel from the PoV of the chain. This method
752
// can be used by peers to quickly determine if their graphs are in sync.
753
//
754
// NOTE: This is part of the V1Store interface.
755
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
756
        var highestChanID uint64
×
757
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
758
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
759
                if errors.Is(err, sql.ErrNoRows) {
×
760
                        return nil
×
761
                } else if err != nil {
×
762
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
763
                                err)
×
764
                }
×
765

766
                highestChanID = byteOrder.Uint64(chanID)
×
767

×
768
                return nil
×
769
        }, sqldb.NoOpReset)
770
        if err != nil {
×
771
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
772
        }
×
773

774
        return highestChanID, nil
×
775
}
776

777
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
778
// within the database for the referenced channel. The `flags` attribute within
779
// the ChannelEdgePolicy determines which of the directed edges are being
780
// updated. If the flag is 1, then the first node's information is being
781
// updated, otherwise it's the second node's information. The node ordering is
782
// determined by the lexicographical ordering of the identity public keys of the
783
// nodes on either side of the channel.
784
//
785
// NOTE: part of the V1Store interface.
786
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
787
        edge *models.ChannelEdgePolicy,
788
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
789

×
790
        var (
×
791
                isUpdate1    bool
×
792
                edgeNotFound bool
×
793
                from, to     route.Vertex
×
794
        )
×
795

×
796
        r := &batch.Request[SQLQueries]{
×
797
                Opts: batch.NewSchedulerOptions(opts...),
×
798
                Reset: func() {
×
799
                        isUpdate1 = false
×
800
                        edgeNotFound = false
×
801
                },
×
802
                Do: func(tx SQLQueries) error {
×
803
                        var err error
×
804
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
805
                                ctx, tx, edge,
×
806
                        )
×
807
                        // It is possible that two of the same policy
×
808
                        // announcements are both being processed in the same
×
809
                        // batch. This may cause the UpsertEdgePolicy conflict to
×
810
                        // be hit since we require at the db layer that the
×
811
                        // new last_update is greater than the existing
×
812
                        // last_update. We need to gracefully handle this here.
×
813
                        if errors.Is(err, sql.ErrNoRows) {
×
814
                                return nil
×
815
                        } else if err != nil {
×
816
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
817
                        }
×
818

819
                        // Silence ErrEdgeNotFound so that the batch can
820
                        // succeed, but propagate the error via local state.
821
                        if errors.Is(err, ErrEdgeNotFound) {
×
822
                                edgeNotFound = true
×
823
                                return nil
×
824
                        }
×
825

826
                        return err
×
827
                },
828
                OnCommit: func(err error) error {
×
829
                        switch {
×
830
                        case err != nil:
×
831
                                return err
×
832
                        case edgeNotFound:
×
833
                                return ErrEdgeNotFound
×
834
                        default:
×
835
                                s.updateEdgeCache(edge, isUpdate1)
×
836
                                return nil
×
837
                        }
838
                },
839
        }
840

841
        err := s.chanScheduler.Execute(ctx, r)
×
842

×
843
        return from, to, err
×
844
}
845

846
// updateEdgeCache updates our reject and channel caches with the new
847
// edge policy information.
848
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
849
        isUpdate1 bool) {
×
850

×
851
        // If an entry for this channel is found in reject cache, we'll modify
×
852
        // the entry with the updated timestamp for the direction that was just
×
853
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
854
        // during the next query for this edge.
×
855
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
856
                if isUpdate1 {
×
857
                        entry.upd1Time = e.LastUpdate.Unix()
×
858
                } else {
×
859
                        entry.upd2Time = e.LastUpdate.Unix()
×
860
                }
×
861
                s.rejectCache.insert(e.ChannelID, entry)
×
862
        }
863

864
        // If an entry for this channel is found in channel cache, we'll modify
865
        // the entry with the updated policy for the direction that was just
866
        // written. If the edge doesn't exist, we'll defer loading the info and
867
        // policies and lazily read from disk during the next query.
868
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
869
                if isUpdate1 {
×
870
                        channel.Policy1 = e
×
871
                } else {
×
872
                        channel.Policy2 = e
×
873
                }
×
874
                s.chanCache.insert(e.ChannelID, channel)
×
875
        }
876
}
877

878
// ForEachSourceNodeChannel iterates through all channels of the source node,
879
// executing the passed callback on each. The callback is provided with the
880
// channel's outpoint, whether we have a policy for the channel and the channel
881
// peer's node information.
882
//
883
// NOTE: part of the V1Store interface.
884
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
885
        cb func(chanPoint wire.OutPoint, havePolicy bool,
886
                otherNode *models.Node) error, reset func()) error {
×
887

×
888
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
889
                nodeID, nodePub, err := s.getSourceNode(
×
890
                        ctx, db, lnwire.GossipVersion1,
×
891
                )
×
892
                if err != nil {
×
893
                        return fmt.Errorf("unable to fetch source node: %w",
×
894
                                err)
×
895
                }
×
896

897
                return forEachNodeChannel(
×
898
                        ctx, db, s.cfg, nodeID,
×
899
                        func(info *models.ChannelEdgeInfo,
×
900
                                outPolicy *models.ChannelEdgePolicy,
×
901
                                _ *models.ChannelEdgePolicy) error {
×
902

×
903
                                // Fetch the other node.
×
904
                                var (
×
905
                                        otherNodePub [33]byte
×
906
                                        node1        = info.NodeKey1Bytes
×
907
                                        node2        = info.NodeKey2Bytes
×
908
                                )
×
909
                                switch {
×
910
                                case bytes.Equal(node1[:], nodePub[:]):
×
911
                                        otherNodePub = node2
×
912
                                case bytes.Equal(node2[:], nodePub[:]):
×
913
                                        otherNodePub = node1
×
914
                                default:
×
915
                                        return fmt.Errorf("node not " +
×
916
                                                "participating in this channel")
×
917
                                }
918

919
                                _, otherNode, err := getNodeByPubKey(
×
920
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
921
                                )
×
922
                                if err != nil {
×
923
                                        return fmt.Errorf("unable to fetch "+
×
924
                                                "other node(%x): %w",
×
925
                                                otherNodePub, err)
×
926
                                }
×
927

928
                                return cb(
×
929
                                        info.ChannelPoint, outPolicy != nil,
×
930
                                        otherNode,
×
931
                                )
×
932
                        },
933
                )
934
        }, reset)
935
}
936

937
// ForEachNode iterates through all the stored vertices/nodes in the graph,
938
// executing the passed callback with each node encountered. If the callback
939
// returns an error, then the transaction is aborted and the iteration stops
940
// early.
941
//
942
// NOTE: part of the V1Store interface.
943
func (s *SQLStore) ForEachNode(ctx context.Context,
944
        cb func(node *models.Node) error, reset func()) error {
×
945

×
946
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
947
                return forEachNodePaginated(
×
948
                        ctx, s.cfg.QueryCfg, db,
×
949
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
950
                                node *models.Node) error {
×
951

×
952
                                return cb(node)
×
953
                        },
×
954
                )
955
        }, reset)
956
}
957

958
// ForEachNodeDirectedChannel iterates through all channels of a given node,
959
// executing the passed callback on the directed edge representing the channel
960
// and its incoming policy. If the callback returns an error, then the iteration
961
// is halted with the error propagated back up to the caller.
962
//
963
// Unknown policies are passed into the callback as nil values.
964
//
965
// NOTE: this is part of the graphdb.NodeTraverser interface.
966
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
967
        cb func(channel *DirectedChannel) error, reset func()) error {
×
968

×
969
        var ctx = context.TODO()
×
970

×
971
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
972
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
973
        }, reset)
×
974
}
975

976
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
977
// graph, executing the passed callback with each node encountered. If the
978
// callback returns an error, then the transaction is aborted and the iteration
979
// stops early.
980
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
981
        cb func(route.Vertex, *lnwire.FeatureVector) error,
982
        reset func()) error {
×
983

×
984
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
985
                return forEachNodeCacheable(
×
986
                        ctx, s.cfg.QueryCfg, db,
×
987
                        func(_ int64, nodePub route.Vertex,
×
988
                                features *lnwire.FeatureVector) error {
×
989

×
990
                                return cb(nodePub, features)
×
991
                        },
×
992
                )
993
        }, reset)
994
        if err != nil {
×
995
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
996
        }
×
997

998
        return nil
×
999
}
1000

1001
// ForEachNodeChannel iterates through all channels of the given node,
1002
// executing the passed callback with an edge info structure and the policies
1003
// of each end of the channel. The first edge policy is the outgoing edge *to*
1004
// the connecting node, while the second is the incoming edge *from* the
1005
// connecting node. If the callback returns an error, then the iteration is
1006
// halted with the error propagated back up to the caller.
1007
//
1008
// Unknown policies are passed into the callback as nil values.
1009
//
1010
// NOTE: part of the V1Store interface.
1011
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
1012
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1013
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1014

×
1015
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1016
                dbNode, err := db.GetNodeByPubKey(
×
1017
                        ctx, sqlc.GetNodeByPubKeyParams{
×
1018
                                Version: int16(lnwire.GossipVersion1),
×
1019
                                PubKey:  nodePub[:],
×
1020
                        },
×
1021
                )
×
1022
                if errors.Is(err, sql.ErrNoRows) {
×
1023
                        return nil
×
1024
                } else if err != nil {
×
1025
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1026
                }
×
1027

1028
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
1029
        }, reset)
1030
}
1031

1032
// extractMaxUpdateTime returns the maximum of the two policy update times.
1033
// This is used for pagination cursor tracking.
1034
func extractMaxUpdateTime(
1035
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1036

×
1037
        switch {
×
1038
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1039
                return max(row.Policy1LastUpdate.Int64,
×
1040
                        row.Policy2LastUpdate.Int64)
×
1041
        case row.Policy1LastUpdate.Valid:
×
1042
                return row.Policy1LastUpdate.Int64
×
1043
        case row.Policy2LastUpdate.Valid:
×
1044
                return row.Policy2LastUpdate.Int64
×
1045
        default:
×
1046
                return 0
×
1047
        }
1048
}
1049

1050
// buildChannelFromRow constructs a ChannelEdge from a database row.
1051
// This includes building the nodes, channel info, and policies.
1052
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1053
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1054

×
1055
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1056
        if err != nil {
×
1057
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1058
                        err)
×
1059
        }
×
1060

1061
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1062
        if err != nil {
×
1063
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1064
                        err)
×
1065
        }
×
1066

1067
        channel, err := getAndBuildEdgeInfo(
×
1068
                ctx, s.cfg, db,
×
1069
                row.GraphChannel, node1.PubKeyBytes,
×
1070
                node2.PubKeyBytes,
×
1071
        )
×
1072
        if err != nil {
×
1073
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1074
                        "channel info: %w", err)
×
1075
        }
×
1076

1077
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1078
        if err != nil {
×
1079
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1080
                        "channel policies: %w", err)
×
1081
        }
×
1082

1083
        p1, p2, err := getAndBuildChanPolicies(
×
1084
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1085
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1086
        )
×
1087
        if err != nil {
×
1088
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1089
                        "channel policies: %w", err)
×
1090
        }
×
1091

1092
        return ChannelEdge{
×
1093
                Info:    channel,
×
1094
                Policy1: p1,
×
1095
                Policy2: p2,
×
1096
                Node1:   node1,
×
1097
                Node2:   node2,
×
1098
        }, nil
×
1099
}
1100

1101
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1102
// This method acquires the cache lock only once for the entire batch.
1103
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1104
        if len(edgesToCache) == 0 {
×
1105
                return
×
1106
        }
×
1107

1108
        s.cacheMu.Lock()
×
1109
        defer s.cacheMu.Unlock()
×
1110

×
1111
        for chanID, edge := range edgesToCache {
×
1112
                s.chanCache.insert(chanID, edge)
×
1113
        }
×
1114
}
1115

1116
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1117
// one edge that has an update timestamp within the specified horizon.
1118
//
1119
// Iterator Lifecycle:
1120
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1121
// 2. Query batch of channels with policies in time range
1122
// 3. For each channel: check if seen, check cache, or build from DB
1123
// 4. Yield channels to caller
1124
// 5. Update cache after successful batch
1125
// 6. Repeat with updated pagination cursor until no more results
1126
//
1127
// NOTE: This is part of the V1Store interface.
1128
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1129
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1130

×
1131
        // Apply options.
×
1132
        cfg := defaultIteratorConfig()
×
1133
        for _, opt := range opts {
×
1134
                opt(cfg)
×
1135
        }
×
1136

1137
        return func(yield func(ChannelEdge, error) bool) {
×
1138
                var (
×
1139
                        ctx            = context.TODO()
×
1140
                        edgesSeen      = make(map[uint64]struct{})
×
1141
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1142
                        hits           int
×
1143
                        total          int
×
1144
                        lastUpdateTime sql.NullInt64
×
1145
                        lastID         sql.NullInt64
×
1146
                        hasMore        = true
×
1147
                )
×
1148

×
1149
                // Each iteration, we'll read a batch of channel updates
×
1150
                // (consulting the cache along the way), yield them, then loop
×
1151
                // back to decide if we have any more updates to read out.
×
1152
                for hasMore {
×
1153
                        var batch []ChannelEdge
×
1154

×
1155
                        // Acquire read lock before starting transaction to
×
1156
                        // ensure consistent lock ordering (cacheMu -> DB) and
×
1157
                        // prevent deadlock with write operations.
×
1158
                        s.cacheMu.RLock()
×
1159

×
1160
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1161
                                func(db SQLQueries) error {
×
1162
                                        //nolint:ll
×
1163
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1164
                                                Version: int16(lnwire.GossipVersion1),
×
1165
                                                StartTime: sqldb.SQLInt64(
×
1166
                                                        startTime.Unix(),
×
1167
                                                ),
×
1168
                                                EndTime: sqldb.SQLInt64(
×
1169
                                                        endTime.Unix(),
×
1170
                                                ),
×
1171
                                                LastUpdateTime: lastUpdateTime,
×
1172
                                                LastID:         lastID,
×
1173
                                                MaxResults: sql.NullInt32{
×
1174
                                                        Int32: int32(
×
1175
                                                                cfg.chanUpdateIterBatchSize,
×
1176
                                                        ),
×
1177
                                                        Valid: true,
×
1178
                                                },
×
1179
                                        }
×
1180
                                        //nolint:ll
×
1181
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1182
                                                ctx, params,
×
1183
                                        )
×
1184
                                        if err != nil {
×
1185
                                                return err
×
1186
                                        }
×
1187

1188
                                        //nolint:ll
1189
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1190

×
1191
                                        //nolint:ll
×
1192
                                        for _, row := range rows {
×
1193
                                                lastUpdateTime = sql.NullInt64{
×
1194
                                                        Int64: extractMaxUpdateTime(row),
×
1195
                                                        Valid: true,
×
1196
                                                }
×
1197
                                                lastID = sql.NullInt64{
×
1198
                                                        Int64: row.GraphChannel.ID,
×
1199
                                                        Valid: true,
×
1200
                                                }
×
1201

×
1202
                                                // Skip if we've already
×
1203
                                                // processed this channel.
×
1204
                                                chanIDInt := byteOrder.Uint64(
×
1205
                                                        row.GraphChannel.Scid,
×
1206
                                                )
×
1207
                                                _, ok := edgesSeen[chanIDInt]
×
1208
                                                if ok {
×
1209
                                                        continue
×
1210
                                                }
1211

1212
                                                // Check cache (we already hold
1213
                                                // shared read lock).
1214
                                                channel, ok := s.chanCache.get(
×
1215
                                                        chanIDInt,
×
1216
                                                )
×
1217
                                                if ok {
×
1218
                                                        hits++
×
1219
                                                        total++
×
1220
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1221
                                                        batch = append(batch, channel)
×
1222

×
1223
                                                        continue
×
1224
                                                }
1225

1226
                                                chanEdge, err := s.buildChannelFromRow(
×
1227
                                                        ctx, db, row,
×
1228
                                                )
×
1229
                                                if err != nil {
×
1230
                                                        return err
×
1231
                                                }
×
1232

1233
                                                edgesSeen[chanIDInt] = struct{}{}
×
1234
                                                edgesToCache[chanIDInt] = chanEdge
×
1235

×
1236
                                                batch = append(batch, chanEdge)
×
1237

×
1238
                                                total++
×
1239
                                        }
1240

1241
                                        return nil
×
1242
                                }, func() {
×
1243
                                        batch = nil
×
1244
                                        edgesSeen = make(map[uint64]struct{})
×
1245
                                        edgesToCache = make(
×
1246
                                                map[uint64]ChannelEdge,
×
1247
                                        )
×
1248
                                })
×
1249

1250
                        // Release read lock after transaction completes.
1251
                        s.cacheMu.RUnlock()
×
1252

×
1253
                        if err != nil {
×
1254
                                log.Errorf("ChanUpdatesInHorizon "+
×
1255
                                        "batch error: %v", err)
×
1256

×
1257
                                yield(ChannelEdge{}, err)
×
1258

×
1259
                                return
×
1260
                        }
×
1261

1262
                        for _, edge := range batch {
×
1263
                                if !yield(edge, nil) {
×
1264
                                        return
×
1265
                                }
×
1266
                        }
1267

1268
                        // Update cache after successful batch yield, setting
1269
                        // the cache lock only once for the entire batch.
1270
                        s.updateChanCacheBatch(edgesToCache)
×
1271
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1272

×
1273
                        // If the batch didn't yield anything, then we're done.
×
1274
                        if len(batch) == 0 {
×
1275
                                break
×
1276
                        }
1277
                }
1278

1279
                if total > 0 {
×
1280
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1281
                                "%.2f (%d/%d)",
×
1282
                                float64(hits)*100/float64(total), hits, total)
×
1283
                } else {
×
1284
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1285
                                "in horizon (%s, %s)", startTime, endTime)
×
1286
                }
×
1287
        }
1288
}
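
// Illustrative sketch (not part of the original file): consuming the
// ChanUpdatesInHorizon iterator with Go's range-over-func syntax. The
// one-hour horizon below is an arbitrary example value.
func exampleChanUpdatesInHorizon(s *SQLStore) error {
        start := time.Now().Add(-time.Hour)
        end := time.Now()

        for edge, err := range s.ChanUpdatesInHorizon(start, end) {
                if err != nil {
                        // The iterator yields at most one error and then
                        // stops, so we can simply return it here.
                        return err
                }

                log.Debugf("channel %d updated within horizon",
                        edge.Info.ChannelID)
        }

        return nil
}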
1289

1290
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1291
// data to the call-back. If withAddrs is true, then the call-back will also be
1292
// provided with the addresses associated with the node. The address retrieval
1293
// results in an additional round-trip to the database, so it should only be
1294
// used if the addresses are actually needed.
1295
//
1296
// NOTE: part of the V1Store interface.
1297
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1298
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1299
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1300

×
1301
        type nodeCachedBatchData struct {
×
1302
                features      map[int64][]int
×
1303
                addrs         map[int64][]nodeAddress
×
1304
                chanBatchData *batchChannelData
×
1305
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1306
        }
×
1307

×
1308
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1309
                // pageQueryFunc is used to query the next page of nodes.
×
1310
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1311
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1312

×
1313
                        return db.ListNodeIDsAndPubKeys(
×
1314
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1315
                                        Version: int16(lnwire.GossipVersion1),
×
1316
                                        ID:      lastID,
×
1317
                                        Limit:   limit,
×
1318
                                },
×
1319
                        )
×
1320
                }
×
1321

1322
                // batchDataFunc is then used to batch load the data required
1323
                // for each page of nodes.
1324
                batchDataFunc := func(ctx context.Context,
×
1325
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1326

×
1327
                        // Batch load node features.
×
1328
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1329
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1330
                        )
×
1331
                        if err != nil {
×
1332
                                return nil, fmt.Errorf("unable to batch load "+
×
1333
                                        "node features: %w", err)
×
1334
                        }
×
1335

1336
                        // Maybe fetch the node's addresses if requested.
1337
                        var nodeAddrs map[int64][]nodeAddress
×
1338
                        if withAddrs {
×
1339
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1340
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1341
                                )
×
1342
                                if err != nil {
×
1343
                                        return nil, fmt.Errorf("unable to "+
×
1344
                                                "batch load node "+
×
1345
                                                "addresses: %w", err)
×
1346
                                }
×
1347
                        }
1348

1349
                        // Batch load ALL unique channels for ALL nodes in this
1350
                        // page.
1351
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1352
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1353
                                        Version:  int16(lnwire.GossipVersion1),
×
1354
                                        Node1Ids: nodeIDs,
×
1355
                                        Node2Ids: nodeIDs,
×
1356
                                },
×
1357
                        )
×
1358
                        if err != nil {
×
1359
                                return nil, fmt.Errorf("unable to batch "+
×
1360
                                        "fetch channels for nodes: %w", err)
×
1361
                        }
×
1362

1363
                        // Deduplicate channels and collect IDs.
1364
                        var (
×
1365
                                allChannelIDs []int64
×
1366
                                allPolicyIDs  []int64
×
1367
                        )
×
1368
                        uniqueChannels := make(
×
1369
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1370
                        )
×
1371

×
1372
                        for _, channel := range allChannels {
×
1373
                                channelID := channel.GraphChannel.ID
×
1374

×
1375
                                // Only process each unique channel once.
×
1376
                                _, exists := uniqueChannels[channelID]
×
1377
                                if exists {
×
1378
                                        continue
×
1379
                                }
1380

1381
                                uniqueChannels[channelID] = channel
×
1382
                                allChannelIDs = append(allChannelIDs, channelID)
×
1383

×
1384
                                if channel.Policy1ID.Valid {
×
1385
                                        allPolicyIDs = append(
×
1386
                                                allPolicyIDs,
×
1387
                                                channel.Policy1ID.Int64,
×
1388
                                        )
×
1389
                                }
×
1390
                                if channel.Policy2ID.Valid {
×
1391
                                        allPolicyIDs = append(
×
1392
                                                allPolicyIDs,
×
1393
                                                channel.Policy2ID.Int64,
×
1394
                                        )
×
1395
                                }
×
1396
                        }
1397

1398
                        // Batch load channel data for all unique channels.
1399
                        channelBatchData, err := batchLoadChannelData(
×
1400
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1401
                                allPolicyIDs,
×
1402
                        )
×
1403
                        if err != nil {
×
1404
                                return nil, fmt.Errorf("unable to batch "+
×
1405
                                        "load channel data: %w", err)
×
1406
                        }
×
1407

1408
                        // Create map of node ID to channels that involve this
1409
                        // node.
1410
                        nodeIDSet := make(map[int64]bool)
×
1411
                        for _, nodeID := range nodeIDs {
×
1412
                                nodeIDSet[nodeID] = true
×
1413
                        }
×
1414

1415
                        nodeChannelMap := make(
×
1416
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1417
                        )
×
1418
                        for _, channel := range uniqueChannels {
×
1419
                                // Add channel to both nodes if they're in our
×
1420
                                // current page.
×
1421
                                node1 := channel.GraphChannel.NodeID1
×
1422
                                if nodeIDSet[node1] {
×
1423
                                        nodeChannelMap[node1] = append(
×
1424
                                                nodeChannelMap[node1], channel,
×
1425
                                        )
×
1426
                                }
×
1427
                                node2 := channel.GraphChannel.NodeID2
×
1428
                                if nodeIDSet[node2] {
×
1429
                                        nodeChannelMap[node2] = append(
×
1430
                                                nodeChannelMap[node2], channel,
×
1431
                                        )
×
1432
                                }
×
1433
                        }
1434

1435
                        return &nodeCachedBatchData{
×
1436
                                features:      nodeFeatures,
×
1437
                                addrs:         nodeAddrs,
×
1438
                                chanBatchData: channelBatchData,
×
1439
                                chanMap:       nodeChannelMap,
×
1440
                        }, nil
×
1441
                }
1442

1443
                // processItem is used to process each node in the current page.
1444
                processItem := func(ctx context.Context,
×
1445
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1446
                        batchData *nodeCachedBatchData) error {
×
1447

×
1448
                        // Build feature vector for this node.
×
1449
                        fv := lnwire.EmptyFeatureVector()
×
1450
                        features, exists := batchData.features[nodeData.ID]
×
1451
                        if exists {
×
1452
                                for _, bit := range features {
×
1453
                                        fv.Set(lnwire.FeatureBit(bit))
×
1454
                                }
×
1455
                        }
1456

1457
                        var nodePub route.Vertex
×
1458
                        copy(nodePub[:], nodeData.PubKey)
×
1459

×
1460
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1461

×
1462
                        toNodeCallback := func() route.Vertex {
×
1463
                                return nodePub
×
1464
                        }
×
1465

1466
                        // Build cached channels map for this node.
1467
                        channels := make(map[uint64]*DirectedChannel)
×
1468
                        for _, channelRow := range nodeChannels {
×
1469
                                directedChan, err := buildDirectedChannel(
×
1470
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1471
                                        channelRow, batchData.chanBatchData, fv,
×
1472
                                        toNodeCallback,
×
1473
                                )
×
1474
                                if err != nil {
×
1475
                                        return err
×
1476
                                }
×
1477

1478
                                channels[directedChan.ChannelID] = directedChan
×
1479
                        }
1480

1481
                        addrs, err := buildNodeAddresses(
×
1482
                                batchData.addrs[nodeData.ID],
×
1483
                        )
×
1484
                        if err != nil {
×
1485
                                return fmt.Errorf("unable to build node "+
×
1486
                                        "addresses: %w", err)
×
1487
                        }
×
1488

1489
                        return cb(ctx, nodePub, addrs, channels)
×
1490
                }
1491

1492
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1493
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1494
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1495
                                return node.ID
×
1496
                        },
×
1497
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1498
                                error) {
×
1499

×
1500
                                return node.ID, nil
×
1501
                        },
×
1502
                        batchDataFunc, processItem,
1503
                )
1504
        }, reset)
1505
}
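
// Illustrative sketch (not part of the original file): counting the channels
// of every node via ForEachNodeCached. Addresses are skipped (withAddrs is
// false) to avoid the extra database round-trip described above.
func exampleForEachNodeCached(ctx context.Context, s *SQLStore) error {
        return s.ForEachNodeCached(ctx, false, func(ctx context.Context,
                node route.Vertex, addrs []net.Addr,
                chans map[uint64]*DirectedChannel) error {

                log.Debugf("node %x has %d channels", node[:], len(chans))

                return nil
        }, func() {
                // No partial state to reset on transaction retry in this
                // example.
        })
}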
1506

1507
// ForEachChannelCacheable iterates through all the channel edges stored
1508
// within the graph and invokes the passed callback for each edge. The
1509
// callback takes two edges since this is a directed graph; both the
1510
// in/out edges are visited. If the callback returns an error, then the
1511
// transaction is aborted and the iteration stops early.
1512
//
1513
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1514
// pointer for that particular channel edge routing policy will be
1515
// passed into the callback.
1516
//
1517
// NOTE: this method is like ForEachChannel but fetches only the data
1518
// required for the graph cache.
1519
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1520
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1521
        reset func()) error {
×
1522

×
1523
        ctx := context.TODO()
×
1524

×
1525
        handleChannel := func(_ context.Context,
×
1526
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1527

×
1528
                node1, node2, err := buildNodeVertices(
×
1529
                        row.Node1Pubkey, row.Node2Pubkey,
×
1530
                )
×
1531
                if err != nil {
×
1532
                        return err
×
1533
                }
×
1534

1535
                edge := buildCacheableChannelInfo(
×
1536
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1537
                )
×
1538

×
1539
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1540
                if err != nil {
×
1541
                        return err
×
1542
                }
×
1543

1544
                pol1, pol2, err := buildCachedChanPolicies(
×
1545
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1546
                )
×
1547
                if err != nil {
×
1548
                        return err
×
1549
                }
×
1550

1551
                return cb(edge, pol1, pol2)
×
1552
        }
1553

1554
        extractCursor := func(
×
1555
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1556

×
1557
                return row.ID
×
1558
        }
×
1559

1560
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1561
                //nolint:ll
×
1562
                queryFunc := func(ctx context.Context, lastID int64,
×
1563
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1564
                        error) {
×
1565

×
1566
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1567
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1568
                                        Version: int16(lnwire.GossipVersion1),
×
1569
                                        ID:      lastID,
×
1570
                                        Limit:   limit,
×
1571
                                },
×
1572
                        )
×
1573
                }
×
1574

1575
                return sqldb.ExecutePaginatedQuery(
×
1576
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1577
                        extractCursor, handleChannel,
×
1578
                )
×
1579
        }, reset)
1580
}
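
// Illustrative sketch (not part of the original file): counting the edges
// exposed by ForEachChannelCacheable, which returns only the minimal data
// needed to warm a graph cache. Nil policies are expected for unadvertised
// directions.
func exampleForEachChannelCacheable(s *SQLStore) (int, error) {
        var numEdges int
        err := s.ForEachChannelCacheable(func(info *models.CachedEdgeInfo,
                pol1, pol2 *models.CachedEdgePolicy) error {

                // Either policy may be nil if that direction has not been
                // advertised yet.
                numEdges++

                return nil
        }, func() {
                numEdges = 0
        })

        return numEdges, err
}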
1581

1582
// ForEachChannel iterates through all the channel edges stored within the
1583
// graph and invokes the passed callback for each edge. The callback takes two
1584
// edges since this is a directed graph; both the in/out edges are visited.
1585
// If the callback returns an error, then the transaction is aborted and the
1586
// iteration stops early.
1587
//
1588
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1589
// for that particular channel edge routing policy will be passed into the
1590
// callback.
1591
//
1592
// NOTE: part of the V1Store interface.
1593
func (s *SQLStore) ForEachChannel(ctx context.Context,
1594
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1595
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1596

×
1597
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1598
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1599
        }, reset)
×
1600
}
1601

1602
// FilterChannelRange returns the channel IDs of all known channels which were
1603
// mined in a block height within the passed range. The channel IDs are grouped
1604
// by their common block height. This method can be used to quickly share with a
1605
// peer the set of channels we know of within a particular range to catch them
1606
// up after a period of time offline. If withTimestamps is true then the
1607
// timestamp info of the latest received channel update messages of the channel
1608
// will be included in the response.
1609
//
1610
// NOTE: This is part of the V1Store interface.
1611
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1612
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1613

×
1614
        var (
×
1615
                ctx       = context.TODO()
×
1616
                startSCID = &lnwire.ShortChannelID{
×
1617
                        BlockHeight: startHeight,
×
1618
                }
×
1619
                endSCID = lnwire.ShortChannelID{
×
1620
                        BlockHeight: endHeight,
×
1621
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1622
                        TxPosition:  math.MaxUint16,
×
1623
                }
×
1624
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1625
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1626
        )
×
1627

×
1628
        // 1) get all channels where channelID is between start and end chan ID.
×
1629
// 2) skip if not public (i.e., no channel_proof)
×
1630
        // 3) collect that channel.
×
1631
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1632
        //    and add those timestamps to the collected channel.
×
1633
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1634
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1635
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1636
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1637
                                StartScid: chanIDStart,
×
1638
                                EndScid:   chanIDEnd,
×
1639
                        },
×
1640
                )
×
1641
                if err != nil {
×
1642
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1643
                                err)
×
1644
                }
×
1645

1646
                for _, dbChan := range dbChans {
×
1647
                        cid := lnwire.NewShortChanIDFromInt(
×
1648
                                byteOrder.Uint64(dbChan.Scid),
×
1649
                        )
×
1650
                        chanInfo := NewChannelUpdateInfo(
×
1651
                                cid, time.Time{}, time.Time{},
×
1652
                        )
×
1653

×
1654
                        if !withTimestamps {
×
1655
                                channelsPerBlock[cid.BlockHeight] = append(
×
1656
                                        channelsPerBlock[cid.BlockHeight],
×
1657
                                        chanInfo,
×
1658
                                )
×
1659

×
1660
                                continue
×
1661
                        }
1662

1663
                        //nolint:ll
1664
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1665
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1666
                                        Version:   int16(lnwire.GossipVersion1),
×
1667
                                        ChannelID: dbChan.ID,
×
1668
                                        NodeID:    dbChan.NodeID1,
×
1669
                                },
×
1670
                        )
×
1671
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1672
                                return fmt.Errorf("unable to fetch node1 "+
×
1673
                                        "policy: %w", err)
×
1674
                        } else if err == nil {
×
1675
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1676
                                        node1Policy.LastUpdate.Int64, 0,
×
1677
                                )
×
1678
                        }
×
1679

1680
                        //nolint:ll
1681
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1682
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1683
                                        Version:   int16(lnwire.GossipVersion1),
×
1684
                                        ChannelID: dbChan.ID,
×
1685
                                        NodeID:    dbChan.NodeID2,
×
1686
                                },
×
1687
                        )
×
1688
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1689
                                return fmt.Errorf("unable to fetch node2 "+
×
1690
                                        "policy: %w", err)
×
1691
                        } else if err == nil {
×
1692
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1693
                                        node2Policy.LastUpdate.Int64, 0,
×
1694
                                )
×
1695
                        }
×
1696

1697
                        channelsPerBlock[cid.BlockHeight] = append(
×
1698
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1699
                        )
×
1700
                }
1701

1702
                return nil
×
1703
        }, func() {
×
1704
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1705
        })
×
1706
        if err != nil {
×
1707
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1708
        }
×
1709

1710
        if len(channelsPerBlock) == 0 {
×
1711
                return nil, nil
×
1712
        }
×
1713

1714
        // Return the channel ranges in ascending block height order.
1715
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1716
        slices.Sort(blocks)
×
1717

×
1718
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1719
                return BlockChannelRange{
×
1720
                        Height:   block,
×
1721
                        Channels: channelsPerBlock[block],
×
1722
                }
×
1723
        }), nil
×
1724
}
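
// Illustrative sketch (not part of the original file): answering a
// gossip-style range query by grouping known channels per block. The block
// heights used here are arbitrary example values.
func exampleFilterChannelRange(s *SQLStore) error {
        ranges, err := s.FilterChannelRange(800_000, 800_100, true)
        if err != nil {
                return err
        }

        for _, blockRange := range ranges {
                log.Debugf("height %d has %d channels", blockRange.Height,
                        len(blockRange.Channels))
        }

        return nil
}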
1725

1726
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1727
// zombie. This method is used on an ad-hoc basis, when channels need to be
1728
// marked as zombies outside the normal pruning cycle.
1729
//
1730
// NOTE: part of the V1Store interface.
1731
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1732
        pubKey1, pubKey2 [33]byte) error {
×
1733

×
1734
        ctx := context.TODO()
×
1735

×
1736
        s.cacheMu.Lock()
×
1737
        defer s.cacheMu.Unlock()
×
1738

×
1739
        chanIDB := channelIDToBytes(chanID)
×
1740

×
1741
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1742
                return db.UpsertZombieChannel(
×
1743
                        ctx, sqlc.UpsertZombieChannelParams{
×
1744
                                Version:  int16(lnwire.GossipVersion1),
×
1745
                                Scid:     chanIDB,
×
1746
                                NodeKey1: pubKey1[:],
×
1747
                                NodeKey2: pubKey2[:],
×
1748
                        },
×
1749
                )
×
1750
        }, sqldb.NoOpReset)
×
1751
        if err != nil {
×
1752
                return fmt.Errorf("unable to upsert zombie channel "+
×
1753
                        "(channel_id=%d): %w", chanID, err)
×
1754
        }
×
1755

1756
        s.rejectCache.remove(chanID)
×
1757
        s.chanCache.remove(chanID)
×
NEW
1758
        s.removePublicNodeCache(pubKey1, pubKey2)
×
1759

×
1760
        return nil
×
1761
}
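
// Illustrative sketch (not part of the original file): marking a channel as
// a zombie and later resurrecting it once a fresh update is received. The
// chanID and public keys are assumed to come from the caller.
func exampleZombieLifecycle(s *SQLStore, chanID uint64,
        pub1, pub2 [33]byte) error {

        // Ad-hoc zombie marking outside the normal pruning cycle.
        if err := s.MarkEdgeZombie(chanID, pub1, pub2); err != nil {
                return err
        }

        // Later, once a fresh channel update proves the edge is alive again,
        // clear it from the zombie index.
        return s.MarkEdgeLive(chanID)
}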
1762

1763
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1764
//
1765
// NOTE: part of the V1Store interface.
1766
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1767
        s.cacheMu.Lock()
×
1768
        defer s.cacheMu.Unlock()
×
1769

×
1770
        var (
×
1771
                ctx     = context.TODO()
×
1772
                chanIDB = channelIDToBytes(chanID)
×
1773
        )
×
1774

×
1775
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1776
                res, err := db.DeleteZombieChannel(
×
1777
                        ctx, sqlc.DeleteZombieChannelParams{
×
1778
                                Scid:    chanIDB,
×
1779
                                Version: int16(lnwire.GossipVersion1),
×
1780
                        },
×
1781
                )
×
1782
                if err != nil {
×
1783
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1784
                                err)
×
1785
                }
×
1786

1787
                rows, err := res.RowsAffected()
×
1788
                if err != nil {
×
1789
                        return err
×
1790
                }
×
1791

1792
                if rows == 0 {
×
1793
                        return ErrZombieEdgeNotFound
×
1794
                } else if rows > 1 {
×
1795
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1796
                                "expected 1", rows)
×
1797
                }
×
1798

1799
                return nil
×
1800
        }, sqldb.NoOpReset)
1801
        if err != nil {
×
1802
                return fmt.Errorf("unable to mark edge live "+
×
1803
                        "(channel_id=%d): %w", chanID, err)
×
1804
        }
×
1805

1806
        s.rejectCache.remove(chanID)
×
1807
        s.chanCache.remove(chanID)
×
1808

×
1809
        return err
×
1810
}
1811

1812
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1813
// zombie, then the two node public keys corresponding to this edge are also
1814
// returned.
1815
//
1816
// NOTE: part of the V1Store interface.
1817
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1818
        error) {
×
1819

×
1820
        var (
×
1821
                ctx              = context.TODO()
×
1822
                isZombie         bool
×
1823
                pubKey1, pubKey2 route.Vertex
×
1824
                chanIDB          = channelIDToBytes(chanID)
×
1825
        )
×
1826

×
1827
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1828
                zombie, err := db.GetZombieChannel(
×
1829
                        ctx, sqlc.GetZombieChannelParams{
×
1830
                                Scid:    chanIDB,
×
1831
                                Version: int16(lnwire.GossipVersion1),
×
1832
                        },
×
1833
                )
×
1834
                if errors.Is(err, sql.ErrNoRows) {
×
1835
                        return nil
×
1836
                }
×
1837
                if err != nil {
×
1838
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1839
                                err)
×
1840
                }
×
1841

1842
                copy(pubKey1[:], zombie.NodeKey1)
×
1843
                copy(pubKey2[:], zombie.NodeKey2)
×
1844
                isZombie = true
×
1845

×
1846
                return nil
×
1847
        }, sqldb.NoOpReset)
1848
        if err != nil {
×
1849
                return false, route.Vertex{}, route.Vertex{},
×
1850
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1851
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1852
        }
×
1853

1854
        return isZombie, pubKey1, pubKey2, nil
×
1855
}
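
// Illustrative sketch (not part of the original file): inspecting the zombie
// index for a channel and logging the node keys recorded for it.
func exampleIsZombieEdge(s *SQLStore, chanID uint64) error {
        isZombie, pub1, pub2, err := s.IsZombieEdge(chanID)
        if err != nil {
                return err
        }

        if isZombie {
                log.Debugf("channel %d is a zombie (node1=%x, node2=%x)",
                        chanID, pub1[:], pub2[:])
        }

        return nil
}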
1856

1857
// NumZombies returns the current number of zombie channels in the graph.
1858
//
1859
// NOTE: part of the V1Store interface.
1860
func (s *SQLStore) NumZombies() (uint64, error) {
×
1861
        var (
×
1862
                ctx        = context.TODO()
×
1863
                numZombies uint64
×
1864
        )
×
1865
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1866
                count, err := db.CountZombieChannels(
×
1867
                        ctx, int16(lnwire.GossipVersion1),
×
1868
                )
×
1869
                if err != nil {
×
1870
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1871
                                err)
×
1872
                }
×
1873

1874
                numZombies = uint64(count)
×
1875

×
1876
                return nil
×
1877
        }, sqldb.NoOpReset)
1878
        if err != nil {
×
1879
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1880
        }
×
1881

1882
        return numZombies, nil
×
1883
}
1884

1885
// DeleteChannelEdges removes edges with the given channel IDs from the
1886
// database and marks them as zombies. This ensures that we're unable to re-add
1887
// them to our database again. If an edge does not exist within the
1888
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1889
// true, then when we mark these edges as zombies, we'll set up the keys such
1890
// that we require the node that failed to send the fresh update to be the one
1891
// that resurrects the channel from its zombie state. The markZombie bool
1892
// denotes whether to mark the channel as a zombie.
1893
//
1894
// NOTE: part of the V1Store interface.
1895
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1896
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1897

×
1898
        s.cacheMu.Lock()
×
1899
        defer s.cacheMu.Unlock()
×
1900

×
1901
        // Keep track of which channels we end up finding so that we can
×
1902
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1903
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1904
        for _, chanID := range chanIDs {
×
1905
                chanLookup[chanID] = struct{}{}
×
1906
        }
×
1907

1908
        var (
×
1909
                ctx   = context.TODO()
×
1910
                edges []*models.ChannelEdgeInfo
×
1911
        )
×
1912
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1913
                // First, collect all channel rows.
×
1914
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1915
                chanCallBack := func(ctx context.Context,
×
1916
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1917

×
1918
                        // Deleting the entry from the map indicates that we
×
1919
                        // have found the channel.
×
1920
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1921
                        delete(chanLookup, scid)
×
1922

×
1923
                        channelRows = append(channelRows, row)
×
1924

×
1925
                        return nil
×
1926
                }
×
1927

1928
                err := s.forEachChanWithPoliciesInSCIDList(
×
1929
                        ctx, db, chanCallBack, chanIDs,
×
1930
                )
×
1931
                if err != nil {
×
1932
                        return err
×
1933
                }
×
1934

1935
                if len(chanLookup) > 0 {
×
1936
                        return ErrEdgeNotFound
×
1937
                }
×
1938

1939
                if len(channelRows) == 0 {
×
1940
                        return nil
×
1941
                }
×
1942

1943
                // Batch build all channel edges.
1944
                var chanIDsToDelete []int64
×
1945
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1946
                        ctx, s.cfg, db, channelRows,
×
1947
                )
×
1948
                if err != nil {
×
1949
                        return err
×
1950
                }
×
1951

1952
                if markZombie {
×
1953
                        for i, row := range channelRows {
×
1954
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1955

×
1956
                                err := handleZombieMarking(
×
1957
                                        ctx, db, row, edges[i],
×
1958
                                        strictZombiePruning, scid,
×
1959
                                )
×
1960
                                if err != nil {
×
1961
                                        return fmt.Errorf("unable to mark "+
×
1962
                                                "channel as zombie: %w", err)
×
1963
                                }
×
1964
                        }
1965
                }
1966

1967
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1968
        }, func() {
×
1969
                edges = nil
×
1970

×
1971
                // Re-fill the lookup map.
×
1972
                for _, chanID := range chanIDs {
×
1973
                        chanLookup[chanID] = struct{}{}
×
1974
                }
×
1975
        })
1976
        if err != nil {
×
1977
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1978
                        err)
×
1979
        }
×
1980

1981
        for _, chanID := range chanIDs {
×
1982
                s.rejectCache.remove(chanID)
×
1983
                s.chanCache.remove(chanID)
×
1984
        }
×
1985

NEW
1986
        var pubkeys [][33]byte
×
NEW
1987
        for _, edge := range edges {
×
NEW
1988
                pubkeys = append(
×
NEW
1989
                        pubkeys, edge.NodeKey1Bytes, edge.NodeKey2Bytes,
×
NEW
1990
                )
×
NEW
1991
        }
×
NEW
1992
        s.removePublicNodeCache(pubkeys...)
×
NEW
1993

×
UNCOV
1994
        return edges, nil
×
1995
}
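
// Illustrative sketch (not part of the original file): removing a set of
// closed channels and marking them as zombies with strict pruning, so only
// the node that failed to send a fresh update can resurrect them. The
// closedSCIDs slice is assumed to come from the caller.
func exampleDeleteChannelEdges(s *SQLStore, closedSCIDs []uint64) error {
        removed, err := s.DeleteChannelEdges(true, true, closedSCIDs...)
        if err != nil {
                return err
        }

        log.Debugf("removed %d channel edges", len(removed))

        return nil
}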
1996

1997
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1998
// channel identified by the channel ID. If the channel can't be found, then
1999
// ErrEdgeNotFound is returned. A struct which houses the general information
2000
// for the channel itself is returned as well as two structs that contain the
2001
// routing policies for the channel in either direction.
2002
//
2003
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
2004
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
2005
// the ChannelEdgeInfo will only include the public keys of each node.
2006
//
2007
// NOTE: part of the V1Store interface.
2008
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
2009
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2010
        *models.ChannelEdgePolicy, error) {
×
2011

×
2012
        var (
×
2013
                ctx              = context.TODO()
×
2014
                edge             *models.ChannelEdgeInfo
×
2015
                policy1, policy2 *models.ChannelEdgePolicy
×
2016
                chanIDB          = channelIDToBytes(chanID)
×
2017
        )
×
2018
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2019
                row, err := db.GetChannelBySCIDWithPolicies(
×
2020
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2021
                                Scid:    chanIDB,
×
2022
                                Version: int16(lnwire.GossipVersion1),
×
2023
                        },
×
2024
                )
×
2025
                if errors.Is(err, sql.ErrNoRows) {
×
2026
                        // First check if this edge is perhaps in the zombie
×
2027
                        // index.
×
2028
                        zombie, err := db.GetZombieChannel(
×
2029
                                ctx, sqlc.GetZombieChannelParams{
×
2030
                                        Scid:    chanIDB,
×
2031
                                        Version: int16(lnwire.GossipVersion1),
×
2032
                                },
×
2033
                        )
×
2034
                        if errors.Is(err, sql.ErrNoRows) {
×
2035
                                return ErrEdgeNotFound
×
2036
                        } else if err != nil {
×
2037
                                return fmt.Errorf("unable to check if "+
×
2038
                                        "channel is zombie: %w", err)
×
2039
                        }
×
2040

2041
                        // At this point, we know the channel is a zombie, so
2042
                        // we'll return an error indicating this, and we will
2043
                        // populate the edge info with the public keys of each
2044
                        // party as this is the only information we have about
2045
                        // it.
2046
                        edge = &models.ChannelEdgeInfo{}
×
2047
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
2048
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
2049

×
2050
                        return ErrZombieEdge
×
2051
                } else if err != nil {
×
2052
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2053
                }
×
2054

2055
                node1, node2, err := buildNodeVertices(
×
2056
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2057
                )
×
2058
                if err != nil {
×
2059
                        return err
×
2060
                }
×
2061

2062
                edge, err = getAndBuildEdgeInfo(
×
2063
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2064
                )
×
2065
                if err != nil {
×
2066
                        return fmt.Errorf("unable to build channel info: %w",
×
2067
                                err)
×
2068
                }
×
2069

2070
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2071
                if err != nil {
×
2072
                        return fmt.Errorf("unable to extract channel "+
×
2073
                                "policies: %w", err)
×
2074
                }
×
2075

2076
                policy1, policy2, err = getAndBuildChanPolicies(
×
2077
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2078
                        node1, node2,
×
2079
                )
×
2080
                if err != nil {
×
2081
                        return fmt.Errorf("unable to build channel "+
×
2082
                                "policies: %w", err)
×
2083
                }
×
2084

2085
                return nil
×
2086
        }, sqldb.NoOpReset)
2087
        if err != nil {
×
2088
                // If we are returning the ErrZombieEdge, then we also need to
×
2089
                // return the edge info as the method comment indicates that
×
2090
                // this will be populated when the edge is a zombie.
×
2091
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2092
                        err)
×
2093
        }
×
2094

2095
        return edge, policy1, policy2, nil
×
2096
}
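
// Illustrative sketch (not part of the original file): handling the special
// ErrZombieEdge case of FetchChannelEdgesByID, where the returned edge info
// only carries the two node public keys.
func exampleFetchChannelEdgesByID(s *SQLStore, chanID uint64) error {
        info, pol1, pol2, err := s.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                // Both policies are nil here; only the node keys of the
                // zombie edge are available.
                log.Debugf("channel %d is a zombie between %x and %x",
                        chanID, info.NodeKey1Bytes[:], info.NodeKey2Bytes[:])

                return nil

        case err != nil:
                return err
        }

        // At this point the channel is live; either policy may still be nil
        // if that direction has not been advertised.
        log.Debugf("channel %d: have policy1=%v, policy2=%v",
                info.ChannelID, pol1 != nil, pol2 != nil)

        return nil
}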
2097

2098
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2099
// the channel identified by the funding outpoint. If the channel can't be
2100
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2101
// information for the channel itself is returned as well as two structs that
2102
// contain the routing policies for the channel in either direction.
2103
//
2104
// NOTE: part of the V1Store interface.
2105
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2106
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2107
        *models.ChannelEdgePolicy, error) {
×
2108

×
2109
        var (
×
2110
                ctx              = context.TODO()
×
2111
                edge             *models.ChannelEdgeInfo
×
2112
                policy1, policy2 *models.ChannelEdgePolicy
×
2113
        )
×
2114
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2115
                row, err := db.GetChannelByOutpointWithPolicies(
×
2116
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2117
                                Outpoint: op.String(),
×
2118
                                Version:  int16(lnwire.GossipVersion1),
×
2119
                        },
×
2120
                )
×
2121
                if errors.Is(err, sql.ErrNoRows) {
×
2122
                        return ErrEdgeNotFound
×
2123
                } else if err != nil {
×
2124
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2125
                }
×
2126

2127
                node1, node2, err := buildNodeVertices(
×
2128
                        row.Node1Pubkey, row.Node2Pubkey,
×
2129
                )
×
2130
                if err != nil {
×
2131
                        return err
×
2132
                }
×
2133

2134
                edge, err = getAndBuildEdgeInfo(
×
2135
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2136
                )
×
2137
                if err != nil {
×
2138
                        return fmt.Errorf("unable to build channel info: %w",
×
2139
                                err)
×
2140
                }
×
2141

2142
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2143
                if err != nil {
×
2144
                        return fmt.Errorf("unable to extract channel "+
×
2145
                                "policies: %w", err)
×
2146
                }
×
2147

2148
                policy1, policy2, err = getAndBuildChanPolicies(
×
2149
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2150
                        node1, node2,
×
2151
                )
×
2152
                if err != nil {
×
2153
                        return fmt.Errorf("unable to build channel "+
×
2154
                                "policies: %w", err)
×
2155
                }
×
2156

2157
                return nil
×
2158
        }, sqldb.NoOpReset)
2159
        if err != nil {
×
2160
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2161
                        err)
×
2162
        }
×
2163

2164
        return edge, policy1, policy2, nil
×
2165
}
2166

2167
// HasChannelEdge returns true if the database knows of a channel edge with the
2168
// passed channel ID, and false otherwise. If an edge with that ID is found
2169
// within the graph, then two time stamps representing the last time the edge
2170
// was updated for both directed edges are returned along with the boolean. If
2171
// it is not found, then the zombie index is checked and its result is returned
2172
// as the second boolean.
2173
//
2174
// NOTE: part of the V1Store interface.
2175
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2176
        bool, error) {
×
2177

×
2178
        ctx := context.TODO()
×
2179

×
2180
        var (
×
2181
                exists          bool
×
2182
                isZombie        bool
×
2183
                node1LastUpdate time.Time
×
2184
                node2LastUpdate time.Time
×
2185
        )
×
2186

×
2187
        // We'll query the cache with the shared lock held to allow multiple
×
2188
        // readers to access values in the cache concurrently if they exist.
×
2189
        s.cacheMu.RLock()
×
2190
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2191
                s.cacheMu.RUnlock()
×
2192
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2193
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2194
                exists, isZombie = entry.flags.unpack()
×
2195

×
2196
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2197
        }
×
2198
        s.cacheMu.RUnlock()
×
2199

×
2200
        s.cacheMu.Lock()
×
2201
        defer s.cacheMu.Unlock()
×
2202

×
2203
        // The item was not found with the shared lock, so we'll acquire the
×
2204
        // exclusive lock and check the cache again in case another method added
×
2205
        // the entry to the cache while no lock was held.
×
2206
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2207
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2208
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2209
                exists, isZombie = entry.flags.unpack()
×
2210

×
2211
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2212
        }
×
2213

2214
        chanIDB := channelIDToBytes(chanID)
×
2215
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2216
                channel, err := db.GetChannelBySCID(
×
2217
                        ctx, sqlc.GetChannelBySCIDParams{
×
2218
                                Scid:    chanIDB,
×
2219
                                Version: int16(lnwire.GossipVersion1),
×
2220
                        },
×
2221
                )
×
2222
                if errors.Is(err, sql.ErrNoRows) {
×
2223
                        // Check if it is a zombie channel.
×
2224
                        isZombie, err = db.IsZombieChannel(
×
2225
                                ctx, sqlc.IsZombieChannelParams{
×
2226
                                        Scid:    chanIDB,
×
2227
                                        Version: int16(lnwire.GossipVersion1),
×
2228
                                },
×
2229
                        )
×
2230
                        if err != nil {
×
2231
                                return fmt.Errorf("could not check if channel "+
×
2232
                                        "is zombie: %w", err)
×
2233
                        }
×
2234

2235
                        return nil
×
2236
                } else if err != nil {
×
2237
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2238
                }
×
2239

2240
                exists = true
×
2241

×
2242
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2243
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2244
                                Version:   int16(lnwire.GossipVersion1),
×
2245
                                ChannelID: channel.ID,
×
2246
                                NodeID:    channel.NodeID1,
×
2247
                        },
×
2248
                )
×
2249
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2250
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2251
                                err)
×
2252
                } else if err == nil {
×
2253
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2254
                }
×
2255

2256
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2257
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2258
                                Version:   int16(lnwire.GossipVersion1),
×
2259
                                ChannelID: channel.ID,
×
2260
                                NodeID:    channel.NodeID2,
×
2261
                        },
×
2262
                )
×
2263
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2264
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2265
                                err)
×
2266
                } else if err == nil {
×
2267
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2268
                }
×
2269

2270
                return nil
×
2271
        }, sqldb.NoOpReset)
2272
        if err != nil {
×
2273
                return time.Time{}, time.Time{}, false, false,
×
2274
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2275
        }
×
2276

2277
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2278
                upd1Time: node1LastUpdate.Unix(),
×
2279
                upd2Time: node2LastUpdate.Unix(),
×
2280
                flags:    packRejectFlags(exists, isZombie),
×
2281
        })
×
2282

×
2283
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2284
}
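
// Illustrative sketch (not part of the original file): interpreting the
// return values of HasChannelEdge when deciding whether to process a fresh
// channel update for the given SCID.
func exampleHasChannelEdge(s *SQLStore, chanID uint64) (bool, error) {
        upd1, upd2, exists, isZombie, err := s.HasChannelEdge(chanID)
        if err != nil {
                return false, err
        }

        // Ignore updates for channels we consider zombies.
        if isZombie {
                return false, nil
        }

        // A brand new channel is always interesting.
        if !exists {
                return true, nil
        }

        log.Debugf("channel %d last updated at %v/%v", chanID, upd1, upd2)

        return true, nil
}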
2285

2286
// ChannelID attempts to lookup the 8-byte compact channel ID which maps to the
2287
// passed channel point (outpoint). If the passed channel doesn't exist within
2288
// the database, then ErrEdgeNotFound is returned.
2289
//
2290
// NOTE: part of the V1Store interface.
2291
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2292
        var (
×
2293
                ctx       = context.TODO()
×
2294
                channelID uint64
×
2295
        )
×
2296
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2297
                chanID, err := db.GetSCIDByOutpoint(
×
2298
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2299
                                Outpoint: chanPoint.String(),
×
2300
                                Version:  int16(lnwire.GossipVersion1),
×
2301
                        },
×
2302
                )
×
2303
                if errors.Is(err, sql.ErrNoRows) {
×
2304
                        return ErrEdgeNotFound
×
2305
                } else if err != nil {
×
2306
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2307
                                err)
×
2308
                }
×
2309

2310
                channelID = byteOrder.Uint64(chanID)
×
2311

×
2312
                return nil
×
2313
        }, sqldb.NoOpReset)
2314
        if err != nil {
×
2315
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2316
        }
×
2317

2318
        return channelID, nil
×
2319
}
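
// Editor's illustrative sketch (hypothetical helper, not part of lnd): the
// ChannelID query above relies on channelIDToBytes and byteOrder agreeing on a
// fixed-endian, 8-byte encoding of the SCID key stored in the channels table.
// Under that assumption, encoding and decoding an SCID is a lossless round
// trip.
func exampleSCIDRoundTrip(scid uint64) uint64 {
        // Encode the compact channel ID into the 8-byte key format used by
        // the SQL schema.
        key := channelIDToBytes(scid)

        // Decoding the key recovers the original compact channel ID.
        return byteOrder.Uint64(key)
}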

// IsPublicNode is a helper method that determines whether the node with the
// given public key is seen as a public node in the graph from the graph's
// source node's point of view.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
        ctx := context.TODO()

        // Check the cache first and return early if there is a hit.
        cached, err := s.publicNodeCache.Get(pubKey)
        if err == nil && cached != nil {
                return cached.isPublic, nil
        }

        // Log any error other than a cache miss (ErrElementNotFound).
        if err != nil && !errors.Is(err, cache.ErrElementNotFound) {
                log.Warnf("Unable to check cache if node is public: %v", err)
        }

        var isPublic bool
        err = s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to check if node is "+
                        "public: %w", err)
        }

        // Store the result in the cache.
        _, err = s.publicNodeCache.Put(pubKey, &cachedPublicNode{
                isPublic: isPublic,
        })
        if err != nil {
                log.Warnf("Unable to store node info in cache: %v", err)
        }

        return isPublic, nil
}
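
// Editor's note (illustrative sketch, not necessarily how the PR wires it up):
// publicNodeCache and cachedPublicNode are assumed to be backed by the generic
// neutrino LRU cache that this file already imports. Under that assumption the
// cached value only needs to report a size for the LRU accounting, and the
// cache can be built with a fixed capacity (the capacity below is made up):
//
//        type examplePublicNodeEntry struct {
//                isPublic bool
//        }
//
//        // Size makes the entry satisfy the cache.Value interface; every
//        // entry counts as a single element.
//        func (e *examplePublicNodeEntry) Size() (uint64, error) {
//                return 1, nil
//        }
//
//        func newExamplePublicNodeCache() *lru.Cache[[33]byte, *examplePublicNodeEntry] {
//                return lru.NewCache[[33]byte, *examplePublicNodeEntry](10_000)
//        }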

// FetchChanInfos returns the set of channel edges that correspond to the passed
// channel IDs. If an edge in the query is unknown to the database, it will be
// skipped and the result will contain only those edges that exist at the time
// of the query. This can be used to respond to peer queries that are seeking to
// fill in gaps in their view of the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
        var (
                ctx   = context.TODO()
                edges = make(map[uint64]ChannelEdge)
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows.
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
                chanCallBack := func(ctx context.Context,
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {

                        channelRows = append(channelRows, row)
                        return nil
                }

                err := s.forEachChanWithPoliciesInSCIDList(
                        ctx, db, chanCallBack, chanIDs,
                )
                if err != nil {
                        return err
                }

                if len(channelRows) == 0 {
                        return nil
                }

                // Batch build all channel edges.
                chans, err := batchBuildChannelEdges(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel edges: %w",
                                err)
                }

                for _, c := range chans {
                        edges[c.Info.ChannelID] = c
                }

                return err
        }, func() {
                clear(edges)
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        res := make([]ChannelEdge, 0, len(edges))
        for _, chanID := range chanIDs {
                edge, ok := edges[chanID]
                if !ok {
                        continue
                }

                res = append(res, edge)
        }

        return res, nil
}
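
// Editor's usage sketch (hypothetical variable names, for illustration only):
// FetchChanInfos silently drops unknown SCIDs, so the result can be shorter
// than the request:
//
//        edges, err := store.FetchChanInfos(reqSCIDs)
//        if err != nil {
//                return err
//        }
//        // len(edges) <= len(reqSCIDs); only known channels are returned, in
//        // the order they were requested.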

// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
        db SQLQueries, cb func(ctx context.Context,
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
        chanIDs []uint64) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
                error) {

                return db.GetChannelsBySCIDWithPolicies(
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
                cb,
        )
}
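
// Editor's illustrative sketch (hypothetical, not the sqldb implementation):
// the ExecuteBatchQuery helper used above is assumed to chunk the converted
// keys, run one query per chunk, and feed every returned row to the supplied
// callback. The generic outline of that pattern looks roughly like this:
func exampleBatchQuery[K, R any](ctx context.Context, chunkSize int, keys []K,
        query func(context.Context, []K) ([]R, error),
        cb func(context.Context, R) error) error {

        for start := 0; start < len(keys); start += chunkSize {
                end := start + chunkSize
                if end > len(keys) {
                        end = len(keys)
                }

                // Execute the query for this chunk of keys.
                rows, err := query(ctx, keys[start:end])
                if err != nil {
                        return err
                }

                // Hand each returned row to the caller's callback.
                for _, row := range rows {
                        if err := cb(ctx, row); err != nil {
                                return err
                        }
                }
        }

        return nil
}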

// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan
// IDs that we don't know of and that are not known zombies of the passed set.
// In other words, we perform a set difference of our set of chan IDs and the
// ones passed in. This method can be used by callers to determine the set of
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
// known zombies are also returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
        []ChannelUpdateInfo, error) {

        var (
                ctx          = context.TODO()
                newChanIDs   []uint64
                knownZombies []ChannelUpdateInfo
                infoLookup   = make(
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
                )
        )

        // We first build a lookup map of the channel IDs to the
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
        // already know about.
        for _, chanInfo := range chansInfo {
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
        }

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // The call-back function deletes known channels from
                // infoLookup, so that we can later check which channels are
                // zombies by only looking at the remaining channels in the set.
                cb := func(ctx context.Context,
                        channel sqlc.GraphChannel) error {

                        delete(infoLookup, byteOrder.Uint64(channel.Scid))

                        return nil
                }

                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
                if err != nil {
                        return fmt.Errorf("unable to iterate through "+
                                "channels: %w", err)
                }

                // We want to ensure that we deal with the channels in the
                // same order that they were passed in, so we iterate over the
                // original chansInfo slice and then check if that channel is
                // still in the infoLookup map.
                for _, chanInfo := range chansInfo {
                        channelID := chanInfo.ShortChannelID.ToUint64()
                        if _, ok := infoLookup[channelID]; !ok {
                                continue
                        }

                        isZombie, err := db.IsZombieChannel(
                                ctx, sqlc.IsZombieChannelParams{
                                        Scid:    channelIDToBytes(channelID),
                                        Version: int16(lnwire.GossipVersion1),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to fetch zombie "+
                                        "channel: %w", err)
                        }

                        if isZombie {
                                knownZombies = append(knownZombies, chanInfo)

                                continue
                        }

                        newChanIDs = append(newChanIDs, channelID)
                }

                return nil
        }, func() {
                newChanIDs = nil
                knownZombies = nil
                // Rebuild the infoLookup map in case of a rollback.
                for _, chanInfo := range chansInfo {
                        scid := chanInfo.ShortChannelID.ToUint64()
                        infoLookup[scid] = chanInfo
                }
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        return newChanIDs, knownZombies, nil
}
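
// Editor's usage sketch (hypothetical variable names, for illustration only):
// given the ChannelUpdateInfos learned from a peer, FilterKnownChanIDs returns
// the SCIDs we have never seen alongside the subset we only know as zombies:
//
//        newSCIDs, zombies, err := store.FilterKnownChanIDs(peerChans)
//        if err != nil {
//                return err
//        }
//        // newSCIDs are candidates to request from the peer, while zombies
//        // can be re-evaluated before being queried again.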

// forEachChanInSCIDList is a helper method that executes a paged query
// against the database to fetch all channels that match the passed
// ChannelUpdateInfo slice. The callback function is called for each channel
// that is found.
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
        chansInfo []ChannelUpdateInfo) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GraphChannel, error) {

                return db.GetChannelsBySCIDs(
                        ctx, sqlc.GetChannelsBySCIDsParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
                channelID := chanInfo.ShortChannelID.ToUint64()

                return channelIDToBytes(channelID)
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
                cb,
        )
}

// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This
// ensures that we only maintain a graph of reachable nodes. In the event that
// a pruned node gains more channels, it will be re-added back to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
        var ctx = context.TODO()

        var prunedNodes []route.Vertex
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                var err error
                prunedNodes, err = s.pruneGraphNodes(ctx, db)

                return err
        }, func() {
                prunedNodes = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
        }

        return prunedNodes, nil
}
2606

2607
// PruneGraph prunes newly closed channels from the channel graph in response
2608
// to a new block being solved on the network. Any transactions which spend the
2609
// funding output of any known channels within he graph will be deleted.
2610
// Additionally, the "prune tip", or the last block which has been used to
2611
// prune the graph is stored so callers can ensure the graph is fully in sync
2612
// with the current UTXO state. A slice of channels that have been closed by
2613
// the target block along with any pruned nodes are returned if the function
2614
// succeeds without error.
2615
//
2616
// NOTE: part of the V1Store interface.
2617
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2618
        blockHash *chainhash.Hash, blockHeight uint32) (
2619
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2620

×
2621
        ctx := context.TODO()
×
2622

×
2623
        s.cacheMu.Lock()
×
2624
        defer s.cacheMu.Unlock()
×
2625

×
2626
        var (
×
2627
                closedChans []*models.ChannelEdgeInfo
×
2628
                prunedNodes []route.Vertex
×
2629
        )
×
2630
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2631
                // First, collect all channel rows that need to be pruned.
×
2632
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2633
                channelCallback := func(ctx context.Context,
×
2634
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2635

×
2636
                        channelRows = append(channelRows, row)
×
2637

×
2638
                        return nil
×
2639
                }
×
2640

2641
                err := s.forEachChanInOutpoints(
×
2642
                        ctx, db, spentOutputs, channelCallback,
×
2643
                )
×
2644
                if err != nil {
×
2645
                        return fmt.Errorf("unable to fetch channels by "+
×
2646
                                "outpoints: %w", err)
×
2647
                }
×
2648

2649
                if len(channelRows) == 0 {
×
2650
                        // There are no channels to prune. So we can exit early
×
2651
                        // after updating the prune log.
×
2652
                        err = db.UpsertPruneLogEntry(
×
2653
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2654
                                        BlockHash:   blockHash[:],
×
2655
                                        BlockHeight: int64(blockHeight),
×
2656
                                },
×
2657
                        )
×
2658
                        if err != nil {
×
2659
                                return fmt.Errorf("unable to insert prune log "+
×
2660
                                        "entry: %w", err)
×
2661
                        }
×
2662

2663
                        return nil
×
2664
                }
2665

2666
                // Batch build all channel edges for pruning.
2667
                var chansToDelete []int64
×
2668
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2669
                        ctx, s.cfg, db, channelRows,
×
2670
                )
×
2671
                if err != nil {
×
2672
                        return err
×
2673
                }
×
2674

2675
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2676
                if err != nil {
×
2677
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2678
                }
×
2679

2680
                err = db.UpsertPruneLogEntry(
×
2681
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2682
                                BlockHash:   blockHash[:],
×
2683
                                BlockHeight: int64(blockHeight),
×
2684
                        },
×
2685
                )
×
2686
                if err != nil {
×
2687
                        return fmt.Errorf("unable to insert prune log "+
×
2688
                                "entry: %w", err)
×
2689
                }
×
2690

2691
                // Now that we've pruned some channels, we'll also prune any
2692
                // nodes that no longer have any channels.
2693
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2694
                if err != nil {
×
2695
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2696
                                err)
×
2697
                }
×
2698

2699
                return nil
×
2700
        }, func() {
×
2701
                prunedNodes = nil
×
2702
                closedChans = nil
×
2703
        })
×
2704
        if err != nil {
×
2705
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2706
        }
×
2707

2708
        for _, channel := range closedChans {
×
2709
                s.rejectCache.remove(channel.ChannelID)
×
2710
                s.chanCache.remove(channel.ChannelID)
×
NEW
2711
                s.removePublicNodeCache(
×
NEW
2712
                        channel.NodeKey1Bytes, channel.NodeKey2Bytes,
×
NEW
2713
                )
×
UNCOV
2714
        }
×
2715

2716
        return closedChans, prunedNodes, nil
×
2717
}
2718

2719
// forEachChanInOutpoints is a helper function that executes a paginated
2720
// query to fetch channels by their outpoints and applies the given call-back
2721
// to each.
2722
//
2723
// NOTE: this fetches channels for all protocol versions.
2724
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2725
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2726
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2727

×
2728
        // Create a wrapper that uses the transaction's db instance to execute
×
2729
        // the query.
×
2730
        queryWrapper := func(ctx context.Context,
×
2731
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2732
                error) {
×
2733

×
2734
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2735
        }
×
2736

2737
        // Define the conversion function from Outpoint to string.
2738
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2739
                return outpoint.String()
×
2740
        }
×
2741

2742
        return sqldb.ExecuteBatchQuery(
×
2743
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2744
                queryWrapper, cb,
×
2745
        )
×
2746
}
2747

2748
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2749
        dbIDs []int64) error {
×
2750

×
2751
        // Create a wrapper that uses the transaction's db instance to execute
×
2752
        // the query.
×
2753
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2754
                return nil, db.DeleteChannels(ctx, ids)
×
2755
        }
×
2756

2757
        idConverter := func(id int64) int64 {
×
2758
                return id
×
2759
        }
×
2760

2761
        return sqldb.ExecuteBatchQuery(
×
2762
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2763
                queryWrapper, func(ctx context.Context, _ any) error {
×
2764
                        return nil
×
2765
                },
×
2766
        )
2767
}
2768

2769
// ChannelView returns the verifiable edge information for each active channel
2770
// within the known channel graph. The set of UTXOs (along with their scripts)
2771
// returned are the ones that need to be watched on chain to detect channel
2772
// closes on the resident blockchain.
2773
//
2774
// NOTE: part of the V1Store interface.
2775
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2776
        var (
×
2777
                ctx        = context.TODO()
×
2778
                edgePoints []EdgePoint
×
2779
        )
×
2780

×
2781
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2782
                handleChannel := func(_ context.Context,
×
2783
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2784

×
2785
                        pkScript, err := genMultiSigP2WSH(
×
2786
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2787
                        )
×
2788
                        if err != nil {
×
2789
                                return err
×
2790
                        }
×
2791

2792
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2793
                        if err != nil {
×
2794
                                return err
×
2795
                        }
×
2796

2797
                        edgePoints = append(edgePoints, EdgePoint{
×
2798
                                FundingPkScript: pkScript,
×
2799
                                OutPoint:        *op,
×
2800
                        })
×
2801

×
2802
                        return nil
×
2803
                }
2804

2805
                queryFunc := func(ctx context.Context, lastID int64,
×
2806
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2807

×
2808
                        return db.ListChannelsPaginated(
×
2809
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2810
                                        Version: int16(lnwire.GossipVersion1),
×
2811
                                        ID:      lastID,
×
2812
                                        Limit:   limit,
×
2813
                                },
×
2814
                        )
×
2815
                }
×
2816

2817
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2818
                        return row.ID
×
2819
                }
×
2820

2821
                return sqldb.ExecutePaginatedQuery(
×
2822
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2823
                        extractCursor, handleChannel,
×
2824
                )
×
2825
        }, func() {
×
2826
                edgePoints = nil
×
2827
        })
×
2828
        if err != nil {
×
2829
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2830
        }
×
2831

2832
        return edgePoints, nil
×
2833
}
2834

2835
// PruneTip returns the block height and hash of the latest block that has been
2836
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2837
// to tell if the graph is currently in sync with the current best known UTXO
2838
// state.
2839
//
2840
// NOTE: part of the V1Store interface.
2841
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2842
        var (
×
2843
                ctx       = context.TODO()
×
2844
                tipHash   chainhash.Hash
×
2845
                tipHeight uint32
×
2846
        )
×
2847
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2848
                pruneTip, err := db.GetPruneTip(ctx)
×
2849
                if errors.Is(err, sql.ErrNoRows) {
×
2850
                        return ErrGraphNeverPruned
×
2851
                } else if err != nil {
×
2852
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2853
                }
×
2854

2855
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2856
                tipHeight = uint32(pruneTip.BlockHeight)
×
2857

×
2858
                return nil
×
2859
        }, sqldb.NoOpReset)
2860
        if err != nil {
×
2861
                return nil, 0, err
×
2862
        }
×
2863

2864
        return &tipHash, tipHeight, nil
×
2865
}
2866

2867
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2868
//
2869
// NOTE: this prunes nodes across protocol versions. It will never prune the
2870
// source nodes.
2871
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2872
        db SQLQueries) ([]route.Vertex, error) {
×
2873

×
2874
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2875
        if err != nil {
×
2876
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2877
                        "nodes: %w", err)
×
2878
        }
×
2879

2880
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2881
        for i, nodeKey := range nodeKeys {
×
2882
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2883
                if err != nil {
×
2884
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2885
                                "from bytes: %w", err)
×
2886
                }
×
2887

2888
                prunedNodes[i] = pub
×
2889
        }
2890

2891
        return prunedNodes, nil
×
2892
}
2893

2894
// DisconnectBlockAtHeight is used to indicate that the block specified
2895
// by the passed height has been disconnected from the main chain. This
2896
// will "rewind" the graph back to the height below, deleting channels
2897
// that are no longer confirmed from the graph. The prune log will be
2898
// set to the last prune height valid for the remaining chain.
2899
// Channels that were removed from the graph resulting from the
2900
// disconnected block are returned.
2901
//
2902
// NOTE: part of the V1Store interface.
2903
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2904
        []*models.ChannelEdgeInfo, error) {
×
2905

×
2906
        ctx := context.TODO()
×
2907

×
2908
        var (
×
2909
                // Every channel having a ShortChannelID starting at 'height'
×
2910
                // will no longer be confirmed.
×
2911
                startShortChanID = lnwire.ShortChannelID{
×
2912
                        BlockHeight: height,
×
2913
                }
×
2914

×
2915
                // Delete everything after this height from the db up until the
×
2916
                // SCID alias range.
×
2917
                endShortChanID = aliasmgr.StartingAlias
×
2918

×
2919
                removedChans []*models.ChannelEdgeInfo
×
2920

×
2921
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2922
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2923
        )
×
2924

×
2925
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2926
                rows, err := db.GetChannelsBySCIDRange(
×
2927
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2928
                                StartScid: chanIDStart,
×
2929
                                EndScid:   chanIDEnd,
×
2930
                        },
×
2931
                )
×
2932
                if err != nil {
×
2933
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2934
                }
×
2935

2936
                if len(rows) == 0 {
×
2937
                        // No channels to disconnect, but still clean up prune
×
2938
                        // log.
×
2939
                        return db.DeletePruneLogEntriesInRange(
×
2940
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2941
                                        StartHeight: int64(height),
×
2942
                                        EndHeight: int64(
×
2943
                                                endShortChanID.BlockHeight,
×
2944
                                        ),
×
2945
                                },
×
2946
                        )
×
2947
                }
×
2948

2949
                // Batch build all channel edges for disconnection.
2950
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2951
                        ctx, s.cfg, db, rows,
×
2952
                )
×
2953
                if err != nil {
×
2954
                        return err
×
2955
                }
×
2956

2957
                removedChans = channelEdges
×
2958

×
2959
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2960
                if err != nil {
×
2961
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2962
                }
×
2963

2964
                return db.DeletePruneLogEntriesInRange(
×
2965
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2966
                                StartHeight: int64(height),
×
2967
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2968
                        },
×
2969
                )
×
2970
        }, func() {
×
2971
                removedChans = nil
×
2972
        })
×
2973
        if err != nil {
×
2974
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2975
                        "height: %w", err)
×
2976
        }
×
2977

        s.cacheMu.Lock()
        for _, channel := range removedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
                s.removePublicNodeCache(
                        channel.NodeKey1Bytes, channel.NodeKey2Bytes,
                )
        }
        s.cacheMu.Unlock()

        return removedChans, nil
}
2990

2991
// AddEdgeProof sets the proof of an existing edge in the graph database.
2992
//
2993
// NOTE: part of the V1Store interface.
2994
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2995
        proof *models.ChannelAuthProof) error {
×
2996

×
2997
        var (
×
2998
                ctx       = context.TODO()
×
2999
                scidBytes = channelIDToBytes(scid.ToUint64())
×
3000
        )
×
3001

×
3002
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3003
                res, err := db.AddV1ChannelProof(
×
3004
                        ctx, sqlc.AddV1ChannelProofParams{
×
3005
                                Scid:              scidBytes,
×
3006
                                Node1Signature:    proof.NodeSig1Bytes,
×
3007
                                Node2Signature:    proof.NodeSig2Bytes,
×
3008
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
3009
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
3010
                        },
×
3011
                )
×
3012
                if err != nil {
×
3013
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
3014
                }
×
3015

3016
                n, err := res.RowsAffected()
×
3017
                if err != nil {
×
3018
                        return err
×
3019
                }
×
3020

3021
                if n == 0 {
×
3022
                        return fmt.Errorf("no rows affected when adding edge "+
×
3023
                                "proof for SCID %v", scid)
×
3024
                } else if n > 1 {
×
3025
                        return fmt.Errorf("multiple rows affected when adding "+
×
3026
                                "edge proof for SCID %v: %d rows affected",
×
3027
                                scid, n)
×
3028
                }
×
3029

3030
                return nil
×
3031
        }, sqldb.NoOpReset)
3032
        if err != nil {
×
3033
                return fmt.Errorf("unable to add edge proof: %w", err)
×
3034
        }
×
3035

3036
        return nil
×
3037
}
3038

3039
// PutClosedScid stores a SCID for a closed channel in the database. This is so
3040
// that we can ignore channel announcements that we know to be closed without
3041
// having to validate them and fetch a block.
3042
//
3043
// NOTE: part of the V1Store interface.
3044
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
3045
        var (
×
3046
                ctx     = context.TODO()
×
3047
                chanIDB = channelIDToBytes(scid.ToUint64())
×
3048
        )
×
3049

×
3050
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3051
                return db.InsertClosedChannel(ctx, chanIDB)
×
3052
        }, sqldb.NoOpReset)
×
3053
}
3054

3055
// IsClosedScid checks whether a channel identified by the passed in scid is
3056
// closed. This helps avoid having to perform expensive validation checks.
3057
//
3058
// NOTE: part of the V1Store interface.
3059
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
3060
        var (
×
3061
                ctx      = context.TODO()
×
3062
                isClosed bool
×
3063
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
3064
        )
×
3065
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3066
                var err error
×
3067
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
3068
                if err != nil {
×
3069
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
3070
                                err)
×
3071
                }
×
3072

3073
                return nil
×
3074
        }, sqldb.NoOpReset)
3075
        if err != nil {
×
3076
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3077
                        err)
×
3078
        }
×
3079

3080
        return isClosed, nil
×
3081
}
3082

3083
// GraphSession will provide the call-back with access to a NodeTraverser
3084
// instance which can be used to perform queries against the channel graph.
3085
//
3086
// NOTE: part of the V1Store interface.
3087
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3088
        reset func()) error {
×
3089

×
3090
        var ctx = context.TODO()
×
3091

×
3092
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3093
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3094
        }, reset)
×
3095
}
3096

3097
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3098
// read only transaction for a consistent view of the graph.
3099
type sqlNodeTraverser struct {
3100
        db    SQLQueries
3101
        chain chainhash.Hash
3102
}
3103

3104
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3105
// NodeTraverser interface.
3106
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3107

3108
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3109
func newSQLNodeTraverser(db SQLQueries,
3110
        chain chainhash.Hash) *sqlNodeTraverser {
×
3111

×
3112
        return &sqlNodeTraverser{
×
3113
                db:    db,
×
3114
                chain: chain,
×
3115
        }
×
3116
}
×
3117

3118
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3119
// node.
3120
//
3121
// NOTE: Part of the NodeTraverser interface.
3122
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3123
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3124

×
3125
        ctx := context.TODO()
×
3126

×
3127
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3128
}
×
3129

3130
// FetchNodeFeatures returns the features of the given node. If the node is
3131
// unknown, assume no additional features are supported.
3132
//
3133
// NOTE: Part of the NodeTraverser interface.
3134
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3135
        *lnwire.FeatureVector, error) {
×
3136

×
3137
        ctx := context.TODO()
×
3138

×
3139
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3140
}
×
3141

3142
// forEachNodeDirectedChannel iterates through all channels of a given
3143
// node, executing the passed callback on the directed edge representing the
3144
// channel and its incoming policy. If the node is not found, no error is
3145
// returned.
3146
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3147
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3148

×
3149
        toNodeCallback := func() route.Vertex {
×
3150
                return nodePub
×
3151
        }
×
3152

3153
        dbID, err := db.GetNodeIDByPubKey(
×
3154
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3155
                        Version: int16(lnwire.GossipVersion1),
×
3156
                        PubKey:  nodePub[:],
×
3157
                },
×
3158
        )
×
3159
        if errors.Is(err, sql.ErrNoRows) {
×
3160
                return nil
×
3161
        } else if err != nil {
×
3162
                return fmt.Errorf("unable to fetch node: %w", err)
×
3163
        }
×
3164

3165
        rows, err := db.ListChannelsByNodeID(
×
3166
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3167
                        Version: int16(lnwire.GossipVersion1),
×
3168
                        NodeID1: dbID,
×
3169
                },
×
3170
        )
×
3171
        if err != nil {
×
3172
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3173
        }
×
3174

3175
        // Exit early if there are no channels for this node so we don't
3176
        // do the unnecessary feature fetching.
3177
        if len(rows) == 0 {
×
3178
                return nil
×
3179
        }
×
3180

3181
        features, err := getNodeFeatures(ctx, db, dbID)
×
3182
        if err != nil {
×
3183
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3184
        }
×
3185

3186
        for _, row := range rows {
×
3187
                node1, node2, err := buildNodeVertices(
×
3188
                        row.Node1Pubkey, row.Node2Pubkey,
×
3189
                )
×
3190
                if err != nil {
×
3191
                        return fmt.Errorf("unable to build node vertices: %w",
×
3192
                                err)
×
3193
                }
×
3194

3195
                edge := buildCacheableChannelInfo(
×
3196
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3197
                        node1, node2,
×
3198
                )
×
3199

×
3200
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3201
                if err != nil {
×
3202
                        return err
×
3203
                }
×
3204

3205
                p1, p2, err := buildCachedChanPolicies(
×
3206
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3207
                )
×
3208
                if err != nil {
×
3209
                        return err
×
3210
                }
×
3211

3212
                // Determine the outgoing and incoming policy for this
3213
                // channel and node combo.
3214
                outPolicy, inPolicy := p1, p2
×
3215
                if p1 != nil && node2 == nodePub {
×
3216
                        outPolicy, inPolicy = p2, p1
×
3217
                } else if p2 != nil && node1 != nodePub {
×
3218
                        outPolicy, inPolicy = p2, p1
×
3219
                }
×
3220

3221
                var cachedInPolicy *models.CachedEdgePolicy
×
3222
                if inPolicy != nil {
×
3223
                        cachedInPolicy = inPolicy
×
3224
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3225
                        cachedInPolicy.ToNodeFeatures = features
×
3226
                }
×
3227

3228
                directedChannel := &DirectedChannel{
×
3229
                        ChannelID:    edge.ChannelID,
×
3230
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3231
                        OtherNode:    edge.NodeKey2Bytes,
×
3232
                        Capacity:     edge.Capacity,
×
3233
                        OutPolicySet: outPolicy != nil,
×
3234
                        InPolicy:     cachedInPolicy,
×
3235
                }
×
3236
                if outPolicy != nil {
×
3237
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3238
                                directedChannel.InboundFee = fee
×
3239
                        })
×
3240
                }
3241

3242
                if nodePub == edge.NodeKey2Bytes {
×
3243
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3244
                }
×
3245

3246
                if err := cb(directedChannel); err != nil {
×
3247
                        return err
×
3248
                }
×
3249
        }
3250

3251
        return nil
×
3252
}
3253

3254
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3255
// and executes the provided callback for each node. It does so via pagination
3256
// along with batch loading of the node feature bits.
3257
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3258
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3259
                features *lnwire.FeatureVector) error) error {
×
3260

×
3261
        handleNode := func(_ context.Context,
×
3262
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3263
                featureBits map[int64][]int) error {
×
3264

×
3265
                fv := lnwire.EmptyFeatureVector()
×
3266
                if features, exists := featureBits[dbNode.ID]; exists {
×
3267
                        for _, bit := range features {
×
3268
                                fv.Set(lnwire.FeatureBit(bit))
×
3269
                        }
×
3270
                }
3271

3272
                var pub route.Vertex
×
3273
                copy(pub[:], dbNode.PubKey)
×
3274

×
3275
                return processNode(dbNode.ID, pub, fv)
×
3276
        }
3277

3278
        queryFunc := func(ctx context.Context, lastID int64,
×
3279
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3280

×
3281
                return db.ListNodeIDsAndPubKeys(
×
3282
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3283
                                Version: int16(lnwire.GossipVersion1),
×
3284
                                ID:      lastID,
×
3285
                                Limit:   limit,
×
3286
                        },
×
3287
                )
×
3288
        }
×
3289

3290
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3291
                return row.ID
×
3292
        }
×
3293

3294
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3295
                return node.ID, nil
×
3296
        }
×
3297

3298
        batchQueryFunc := func(ctx context.Context,
×
3299
                nodeIDs []int64) (map[int64][]int, error) {
×
3300

×
3301
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3302
        }
×
3303

3304
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3305
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3306
                batchQueryFunc, handleNode,
×
3307
        )
×
3308
}
3309

3310
// forEachNodeChannel iterates through all channels of a node, executing
3311
// the passed callback on each. The call-back is provided with the channel's
3312
// edge information, the outgoing policy and the incoming policy for the
3313
// channel and node combo.
3314
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3315
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3316
                *models.ChannelEdgePolicy,
3317
                *models.ChannelEdgePolicy) error) error {
×
3318

×
3319
        // Get all the V1 channels for this node.
×
3320
        rows, err := db.ListChannelsByNodeID(
×
3321
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3322
                        Version: int16(lnwire.GossipVersion1),
×
3323
                        NodeID1: id,
×
3324
                },
×
3325
        )
×
3326
        if err != nil {
×
3327
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3328
        }
×
3329

3330
        // Collect all the channel and policy IDs.
3331
        var (
×
3332
                chanIDs   = make([]int64, 0, len(rows))
×
3333
                policyIDs = make([]int64, 0, 2*len(rows))
×
3334
        )
×
3335
        for _, row := range rows {
×
3336
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3337

×
3338
                if row.Policy1ID.Valid {
×
3339
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3340
                }
×
3341
                if row.Policy2ID.Valid {
×
3342
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3343
                }
×
3344
        }
3345

3346
        batchData, err := batchLoadChannelData(
×
3347
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3348
        )
×
3349
        if err != nil {
×
3350
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3351
        }
×
3352

3353
        // Call the call-back for each channel and its known policies.
3354
        for _, row := range rows {
×
3355
                node1, node2, err := buildNodeVertices(
×
3356
                        row.Node1Pubkey, row.Node2Pubkey,
×
3357
                )
×
3358
                if err != nil {
×
3359
                        return fmt.Errorf("unable to build node vertices: %w",
×
3360
                                err)
×
3361
                }
×
3362

3363
                edge, err := buildEdgeInfoWithBatchData(
×
3364
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3365
                        batchData,
×
3366
                )
×
3367
                if err != nil {
×
3368
                        return fmt.Errorf("unable to build channel info: %w",
×
3369
                                err)
×
3370
                }
×
3371

3372
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3373
                if err != nil {
×
3374
                        return fmt.Errorf("unable to extract channel "+
×
3375
                                "policies: %w", err)
×
3376
                }
×
3377

3378
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3379
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3380
                )
×
3381
                if err != nil {
×
3382
                        return fmt.Errorf("unable to build channel "+
×
3383
                                "policies: %w", err)
×
3384
                }
×
3385

3386
                // Determine the outgoing and incoming policy for this
3387
                // channel and node combo.
3388
                p1ToNode := row.GraphChannel.NodeID2
×
3389
                p2ToNode := row.GraphChannel.NodeID1
×
3390
                outPolicy, inPolicy := p1, p2
×
3391
                if (p1 != nil && p1ToNode == id) ||
×
3392
                        (p2 != nil && p2ToNode != id) {
×
3393

×
3394
                        outPolicy, inPolicy = p2, p1
×
3395
                }
×
3396

3397
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3398
                        return err
×
3399
                }
×
3400
        }
3401

3402
        return nil
×
3403
}
3404

3405
// updateChanEdgePolicy upserts the channel policy info we have stored for
3406
// a channel we already know of.
3407
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3408
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3409
        error) {
×
3410

×
3411
        var (
×
3412
                node1Pub, node2Pub route.Vertex
×
3413
                isNode1            bool
×
3414
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3415
        )
×
3416

×
3417
        // Check that this edge policy refers to a channel that we already
×
3418
        // know of. We do this explicitly so that we can return the appropriate
×
3419
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3420
        // abort the transaction which would abort the entire batch.
×
3421
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3422
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3423
                        Scid:    chanIDB,
×
3424
                        Version: int16(lnwire.GossipVersion1),
×
3425
                },
×
3426
        )
×
3427
        if errors.Is(err, sql.ErrNoRows) {
×
3428
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3429
        } else if err != nil {
×
3430
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3431
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3432
        }
×
3433

3434
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3435
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3436

×
3437
        // Figure out which node this edge is from.
×
3438
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3439
        nodeID := dbChan.NodeID1
×
3440
        if !isNode1 {
×
3441
                nodeID = dbChan.NodeID2
×
3442
        }
×
3443

3444
        var (
×
3445
                inboundBase sql.NullInt64
×
3446
                inboundRate sql.NullInt64
×
3447
        )
×
3448
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3449
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3450
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3451
        })
×
3452

3453
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3454
                Version:     int16(lnwire.GossipVersion1),
×
3455
                ChannelID:   dbChan.ID,
×
3456
                NodeID:      nodeID,
×
3457
                Timelock:    int32(edge.TimeLockDelta),
×
3458
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3459
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3460
                MinHtlcMsat: int64(edge.MinHTLC),
×
3461
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3462
                Disabled: sql.NullBool{
×
3463
                        Valid: true,
×
3464
                        Bool:  edge.IsDisabled(),
×
3465
                },
×
3466
                MaxHtlcMsat: sql.NullInt64{
×
3467
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3468
                        Int64: int64(edge.MaxHTLC),
×
3469
                },
×
3470
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3471
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3472
                InboundBaseFeeMsat:      inboundBase,
×
3473
                InboundFeeRateMilliMsat: inboundRate,
×
3474
                Signature:               edge.SigBytes,
×
3475
        })
×
3476
        if err != nil {
×
3477
                return node1Pub, node2Pub, isNode1,
×
3478
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3479
        }
×
3480

3481
        // Convert the flat extra opaque data into a map of TLV types to
3482
        // values.
3483
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3484
        if err != nil {
×
3485
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3486
                        "marshal extra opaque data: %w", err)
×
3487
        }
×
3488

3489
        // Update the channel policy's extra signed fields.
3490
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3491
        if err != nil {
×
3492
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3493
                        "policy extra TLVs: %w", err)
×
3494
        }
×
3495

3496
        return node1Pub, node2Pub, isNode1, nil
×
3497
}
3498

3499
// getNodeByPubKey attempts to look up a target node by its public key.
3500
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3501
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3502

×
3503
        dbNode, err := db.GetNodeByPubKey(
×
3504
                ctx, sqlc.GetNodeByPubKeyParams{
×
3505
                        Version: int16(lnwire.GossipVersion1),
×
3506
                        PubKey:  pubKey[:],
×
3507
                },
×
3508
        )
×
3509
        if errors.Is(err, sql.ErrNoRows) {
×
3510
                return 0, nil, ErrGraphNodeNotFound
×
3511
        } else if err != nil {
×
3512
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3513
        }
×
3514

3515
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3516
        if err != nil {
×
3517
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3518
        }
×
3519

3520
        return dbNode.ID, node, nil
×
3521
}
3522

3523
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3524
// provided parameters.
3525
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3526
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3527

×
3528
        return &models.CachedEdgeInfo{
×
3529
                ChannelID:     byteOrder.Uint64(scid),
×
3530
                NodeKey1Bytes: node1Pub,
×
3531
                NodeKey2Bytes: node2Pub,
×
3532
                Capacity:      btcutil.Amount(capacity),
×
3533
        }
×
3534
}
×
3535

3536
// buildNode constructs a Node instance from the given database node
3537
// record. The node's features, addresses and extra signed fields are also
3538
// fetched from the database and set on the node.
3539
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3540
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3541

×
3542
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3543
        if err != nil {
×
3544
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3545
                        err)
×
3546
        }
×
3547

3548
        return buildNodeWithBatchData(dbNode, data)
×
3549
}
3550

3551
// buildNodeWithBatchData builds a models.Node instance
3552
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3553
// features/addresses/extra fields, then the corresponding fields are expected
3554
// to be present in the batchNodeData.
3555
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3556
        batchData *batchNodeData) (*models.Node, error) {
×
3557

×
3558
        if dbNode.Version != int16(lnwire.GossipVersion1) {
×
3559
                return nil, fmt.Errorf("unsupported node version: %d",
×
3560
                        dbNode.Version)
×
3561
        }
×
3562

3563
        var pub [33]byte
×
3564
        copy(pub[:], dbNode.PubKey)
×
3565

×
3566
        node := models.NewV1ShellNode(pub)
×
3567

×
3568
        if len(dbNode.Signature) == 0 {
×
3569
                return node, nil
×
3570
        }
×
3571

3572
        node.AuthSigBytes = dbNode.Signature
×
3573

×
3574
        if dbNode.Alias.Valid {
×
3575
                node.Alias = fn.Some(dbNode.Alias.String)
×
3576
        }
×
3577
        if dbNode.LastUpdate.Valid {
×
3578
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3579
        }
×
3580

3581
        var err error
×
3582
        if dbNode.Color.Valid {
×
3583
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3584
                if err != nil {
×
3585
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3586
                                err)
×
3587
                }
×
3588

3589
                node.Color = fn.Some(nodeColor)
×
3590
        }
3591

3592
        // Use preloaded features.
3593
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3594
                fv := lnwire.EmptyFeatureVector()
×
3595
                for _, bit := range features {
×
3596
                        fv.Set(lnwire.FeatureBit(bit))
×
3597
                }
×
3598
                node.Features = fv
×
3599
        }
3600

3601
        // Use preloaded addresses.
3602
        addresses, exists := batchData.addresses[dbNode.ID]
×
3603
        if exists && len(addresses) > 0 {
×
3604
                node.Addresses, err = buildNodeAddresses(addresses)
×
3605
                if err != nil {
×
3606
                        return nil, fmt.Errorf("unable to build addresses "+
×
3607
                                "for node(%d): %w", dbNode.ID, err)
×
3608
                }
×
3609
        }
3610

3611
        // Use preloaded extra fields.
3612
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3613
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3614
                if err != nil {
×
3615
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3616
                                "signed fields: %w", err)
×
3617
                }
×
3618
                if len(recs) != 0 {
×
3619
                        node.ExtraOpaqueData = recs
×
3620
                }
×
3621
        }
3622

3623
        return node, nil
×
3624
}
3625

3626
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
// with the preloaded data, and executes the provided callback for each node.
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3629
        db SQLQueries, nodes []sqlc.GraphNode,
3630
        cb func(dbID int64, node *models.Node) error) error {
×
3631

×
3632
        // Extract node IDs for batch loading.
×
3633
        nodeIDs := make([]int64, len(nodes))
×
3634
        for i, node := range nodes {
×
3635
                nodeIDs[i] = node.ID
×
3636
        }
×
3637

3638
        // Batch load all related data for this page.
3639
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3640
        if err != nil {
×
3641
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3642
        }
×
3643

3644
        for _, dbNode := range nodes {
×
3645
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3646
                if err != nil {
×
3647
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3648
                                dbNode.ID, err)
×
3649
                }
×
3650

3651
                if err := cb(dbNode.ID, node); err != nil {
×
3652
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3653
                                dbNode.ID, err)
×
3654
                }
×
3655
        }
3656

3657
        return nil
×
3658
}
3659

3660
// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
3663
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3664

×
3665
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3666
        if err != nil {
×
3667
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3668
                        nodeID, err)
×
3669
        }
×
3670

3671
        features := lnwire.EmptyFeatureVector()
×
3672
        for _, feature := range rows {
×
3673
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3674
        }
×
3675

3676
        return features, nil
×
3677
}
3678

3679
// upsertNodeAncillaryData updates the node's features, addresses, and extra
// signed fields. This is common logic shared by upsertNode and
// upsertSourceNode.
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
3683
        nodeID int64, node *models.Node) error {
×
3684

×
3685
        // Update the node's features.
×
3686
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3687
        if err != nil {
×
3688
                return fmt.Errorf("inserting node features: %w", err)
×
3689
        }
×
3690

3691
        // Update the node's addresses.
3692
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3693
        if err != nil {
×
3694
                return fmt.Errorf("inserting node addresses: %w", err)
×
3695
        }
×
3696

3697
        // Convert the flat extra opaque data into a map of TLV types to
3698
        // values.
3699
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3700
        if err != nil {
×
3701
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3702
                        err)
×
3703
        }
×
3704

3705
        // Update the node's extra signed fields.
3706
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3707
        if err != nil {
×
3708
                return fmt.Errorf("inserting node extra TLVs: %w", err)
×
3709
        }
×
3710

3711
        return nil
×
3712
}
3713

3714
// populateNodeParams populates the common node parameters from a models.Node.
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
func populateNodeParams(node *models.Node,
3717
        setParams func(lastUpdate sql.NullInt64, alias,
3718
                colorStr sql.NullString, signature []byte)) error {
×
3719

×
3720
        if !node.HaveAnnouncement() {
×
3721
                return nil
×
3722
        }
×
3723

3724
        switch node.Version {
×
3725
        case lnwire.GossipVersion1:
×
3726
                lastUpdate := sqldb.SQLInt64(node.LastUpdate.Unix())
×
3727
                var alias, colorStr sql.NullString
×
3728

×
3729
                node.Color.WhenSome(func(rgba color.RGBA) {
×
3730
                        colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
3731
                })
×
3732
                node.Alias.WhenSome(func(s string) {
×
3733
                        alias = sqldb.SQLStrValid(s)
×
3734
                })
×
3735

3736
                setParams(lastUpdate, alias, colorStr, node.AuthSigBytes)
×
3737

3738
        case lnwire.GossipVersion2:
×
3739
                // No-op for now.
3740

3741
        default:
×
3742
                return fmt.Errorf("unknown gossip version: %d", node.Version)
×
3743
        }
3744

3745
        return nil
×
3746
}
3747

3748
// buildNodeUpsertParams builds the parameters for upserting a node using the
// strict UpsertNode query (requires timestamp to be increasing).
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
×
3751
        params := sqlc.UpsertNodeParams{
×
3752
                Version: int16(lnwire.GossipVersion1),
×
3753
                PubKey:  node.PubKeyBytes[:],
×
3754
        }
×
3755

×
3756
        err := populateNodeParams(
×
3757
                node, func(lastUpdate sql.NullInt64, alias,
×
3758
                        colorStr sql.NullString,
×
3759
                        signature []byte) {
×
3760

×
3761
                        params.LastUpdate = lastUpdate
×
3762
                        params.Alias = alias
×
3763
                        params.Color = colorStr
×
3764
                        params.Signature = signature
×
3765
                })
×
3766

3767
        return params, err
×
3768
}
3769

3770
// buildSourceNodeUpsertParams builds the parameters for upserting the source
// node using the lenient UpsertSourceNode query (allows same timestamp).
func buildSourceNodeUpsertParams(node *models.Node) (
3773
        sqlc.UpsertSourceNodeParams, error) {
×
3774

×
3775
        params := sqlc.UpsertSourceNodeParams{
×
3776
                Version: int16(lnwire.GossipVersion1),
×
3777
                PubKey:  node.PubKeyBytes[:],
×
3778
        }
×
3779

×
3780
        err := populateNodeParams(
×
3781
                node, func(lastUpdate sql.NullInt64, alias,
×
3782
                        colorStr sql.NullString, signature []byte) {
×
3783

×
3784
                        params.LastUpdate = lastUpdate
×
3785
                        params.Alias = alias
×
3786
                        params.Color = colorStr
×
3787
                        params.Signature = signature
×
3788
                },
×
3789
        )
3790

3791
        return params, err
×
3792
}
3793

3794
// upsertSourceNode upserts the source node record into the database using a
// less strict upsert that allows updates even when the timestamp hasn't
// changed. This is necessary to handle concurrent updates to our own node
// during startup and runtime. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertSourceNode(ctx context.Context, db SQLQueries,
3800
        node *models.Node) (int64, error) {
×
3801

×
3802
        params, err := buildSourceNodeUpsertParams(node)
×
3803
        if err != nil {
×
3804
                return 0, err
×
3805
        }
×
3806

3807
        nodeID, err := db.UpsertSourceNode(ctx, params)
×
3808
        if err != nil {
×
3809
                return 0, fmt.Errorf("upserting source node(%x): %w",
×
3810
                        node.PubKeyBytes, err)
×
3811
        }
×
3812

3813
        // We can exit here if we don't have the announcement yet.
3814
        if !node.HaveAnnouncement() {
×
3815
                return nodeID, nil
×
3816
        }
×
3817

3818
        // Update the ancillary node data (features, addresses, extra fields).
3819
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3820
        if err != nil {
×
3821
                return 0, err
×
3822
        }
×
3823

3824
        return nodeID, nil
×
3825
}
3826

3827
// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
3832
        node *models.Node) (int64, error) {
×
3833

×
3834
        params, err := buildNodeUpsertParams(node)
×
3835
        if err != nil {
×
3836
                return 0, err
×
3837
        }
×
3838

3839
        nodeID, err := db.UpsertNode(ctx, params)
×
3840
        if err != nil {
×
3841
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3842
                        err)
×
3843
        }
×
3844

3845
        // We can exit here if we don't have the announcement yet.
3846
        if !node.HaveAnnouncement() {
×
3847
                return nodeID, nil
×
3848
        }
×
3849

3850
        // Update the ancillary node data (features, addresses, extra fields).
3851
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3852
        if err != nil {
×
3853
                return 0, err
×
3854
        }
×
3855

3856
        return nodeID, nil
×
3857
}
3858

3859
// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3864
        features *lnwire.FeatureVector) error {
×
3865

×
3866
        // Get any existing features for the node.
×
3867
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3868
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3869
                return err
×
3870
        }
×
3871

3872
        // Copy the node's latest set of feature bits.
3873
        newFeatures := make(map[int32]struct{})
×
3874
        if features != nil {
×
3875
                for feature := range features.Features() {
×
3876
                        newFeatures[int32(feature)] = struct{}{}
×
3877
                }
×
3878
        }
3879

3880
        // For any current feature that already exists in the DB, remove it from
3881
        // the in-memory map. For any existing feature that does not exist in
3882
        // the in-memory map, delete it from the database.
3883
        for _, feature := range existingFeatures {
×
3884
                // The feature is still present, so there are no updates to be
×
3885
                // made.
×
3886
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3887
                        delete(newFeatures, feature.FeatureBit)
×
3888
                        continue
×
3889
                }
3890

3891
                // The feature is no longer present, so we remove it from the
3892
                // database.
3893
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3894
                        NodeID:     nodeID,
×
3895
                        FeatureBit: feature.FeatureBit,
×
3896
                })
×
3897
                if err != nil {
×
3898
                        return fmt.Errorf("unable to delete node(%d) "+
×
3899
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3900
                                err)
×
3901
                }
×
3902
        }
3903

3904
        // Any remaining entries in newFeatures are new features that need to be
3905
        // added to the database for the first time.
3906
        for feature := range newFeatures {
×
3907
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3908
                        NodeID:     nodeID,
×
3909
                        FeatureBit: feature,
×
3910
                })
×
3911
                if err != nil {
×
3912
                        return fmt.Errorf("unable to insert node(%d) "+
×
3913
                                "feature(%v): %w", nodeID, feature, err)
×
3914
                }
×
3915
        }
3916

3917
        return nil
×
3918
}
3919
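// Illustrative sketch (not part of the upstream file): the set difference that
// upsertNodeFeatures performs, shown on plain feature-bit sets. Bits that only
// exist in the old set are deleted, bits that only exist in the new set are
// inserted, and bits present in both are left untouched.
func exampleFeatureBitDiff(oldBits, newBits map[int32]struct{}) (toDelete,
        toInsert []int32) {

        for bit := range oldBits {
                if _, ok := newBits[bit]; !ok {
                        toDelete = append(toDelete, bit)
                }
        }
        for bit := range newBits {
                if _, ok := oldBits[bit]; !ok {
                        toInsert = append(toInsert, bit)
                }
        }

        return toDelete, toInsert
}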

3920
// fetchNodeFeatures fetches the features for a node with the given public key.
3921
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3922
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3923

×
3924
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3925
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3926
                        PubKey:  nodePub[:],
×
3927
                        Version: int16(lnwire.GossipVersion1),
×
3928
                },
×
3929
        )
×
3930
        if err != nil {
×
3931
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3932
                        nodePub, err)
×
3933
        }
×
3934

3935
        features := lnwire.EmptyFeatureVector()
×
3936
        for _, bit := range rows {
×
3937
                features.Set(lnwire.FeatureBit(bit))
×
3938
        }
×
3939

3940
        return features, nil
×
3941
}
3942

3943
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8
3947

3948
const (
3949
        addressTypeIPv4   dbAddressType = 1
3950
        addressTypeIPv6   dbAddressType = 2
3951
        addressTypeTorV2  dbAddressType = 3
3952
        addressTypeTorV3  dbAddressType = 4
3953
        addressTypeDNS    dbAddressType = 5
3954
        addressTypeOpaque dbAddressType = math.MaxInt8
3955
)
3956

3957
// collectAddressRecords collects the addresses from the provided
// net.Addr slice and returns a map of dbAddressType to a slice of address
// strings.
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3961
        error) {
×
3962

×
3963
        // Copy the node's latest set of addresses.
×
3964
        newAddresses := map[dbAddressType][]string{
×
3965
                addressTypeIPv4:   {},
×
3966
                addressTypeIPv6:   {},
×
3967
                addressTypeTorV2:  {},
×
3968
                addressTypeTorV3:  {},
×
3969
                addressTypeDNS:    {},
×
3970
                addressTypeOpaque: {},
×
3971
        }
×
3972
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3973
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3974
        }
×
3975

3976
        for _, address := range addresses {
×
3977
                switch addr := address.(type) {
×
3978
                case *net.TCPAddr:
×
3979
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3980
                                addAddr(addressTypeIPv4, addr)
×
3981
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3982
                                addAddr(addressTypeIPv6, addr)
×
3983
                        } else {
×
3984
                                return nil, fmt.Errorf("unhandled IP "+
×
3985
                                        "address: %v", addr)
×
3986
                        }
×
3987

3988
                case *tor.OnionAddr:
×
3989
                        switch len(addr.OnionService) {
×
3990
                        case tor.V2Len:
×
3991
                                addAddr(addressTypeTorV2, addr)
×
3992
                        case tor.V3Len:
×
3993
                                addAddr(addressTypeTorV3, addr)
×
3994
                        default:
×
3995
                                return nil, fmt.Errorf("invalid length for " +
×
3996
                                        "a tor address")
×
3997
                        }
3998

3999
                case *lnwire.DNSAddress:
×
4000
                        addAddr(addressTypeDNS, addr)
×
4001

4002
                case *lnwire.OpaqueAddrs:
×
4003
                        addAddr(addressTypeOpaque, addr)
×
4004

4005
                default:
×
4006
                        return nil, fmt.Errorf("unhandled address type: %T",
×
4007
                                addr)
×
4008
                }
4009
        }
4010

4011
        return newAddresses, nil
×
4012
}
4013
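// Illustrative sketch (not part of the upstream file): grouping a node's
// advertised addresses by database address type. The addresses below are
// hypothetical documentation values (RFC 5737/3849 ranges) used purely for
// demonstration.
func exampleCollectAddressRecords() (map[dbAddressType][]string, error) {
        addrs := []net.Addr{
                &net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 9735},
                &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 9735},
        }

        // The IPv4 address lands under addressTypeIPv4 and the IPv6 address
        // under addressTypeIPv6, each rendered via net.Addr's String method.
        return collectAddressRecords(addrs)
}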

4014
// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
4021
        addresses []net.Addr) error {
×
4022

×
4023
        // Delete any existing addresses for the node. This is required since
×
4024
        // even if the new set of addresses is the same, the ordering may have
×
4025
        // changed for a given address type.
×
4026
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
4027
        if err != nil {
×
4028
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
4029
                        nodeID, err)
×
4030
        }
×
4031

4032
        newAddresses, err := collectAddressRecords(addresses)
×
4033
        if err != nil {
×
4034
                return err
×
4035
        }
×
4036

4037
        // Insert the full new set of addresses, preserving their order within
        // each address type.
        for addrType, addrList := range newAddresses {
×
4040
                for position, addr := range addrList {
×
4041
                        err := db.UpsertNodeAddress(
×
4042
                                ctx, sqlc.UpsertNodeAddressParams{
×
4043
                                        NodeID:   nodeID,
×
4044
                                        Type:     int16(addrType),
×
4045
                                        Address:  addr,
×
4046
                                        Position: int32(position),
×
4047
                                },
×
4048
                        )
×
4049
                        if err != nil {
×
4050
                                return fmt.Errorf("unable to insert "+
×
4051
                                        "node(%d) address(%v): %w", nodeID,
×
4052
                                        addr, err)
×
4053
                        }
×
4054
                }
4055
        }
4056

4057
        return nil
×
4058
}
4059

4060
// getNodeAddresses fetches the addresses for a node with the given DB ID.
4061
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
4062
        error) {
×
4063

×
4064
        // GetNodeAddresses ensures that the addresses for a given type are
×
4065
        // returned in the same order as they were inserted.
×
4066
        rows, err := db.GetNodeAddresses(ctx, id)
×
4067
        if err != nil {
×
4068
                return nil, err
×
4069
        }
×
4070

4071
        addresses := make([]net.Addr, 0, len(rows))
×
4072
        for _, row := range rows {
×
4073
                address := row.Address
×
4074

×
4075
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
4076
                if err != nil {
×
4077
                        return nil, fmt.Errorf("unable to parse address "+
×
4078
                                "for node(%d): %v: %w", id, address, err)
×
4079
                }
×
4080

4081
                addresses = append(addresses, addr)
×
4082
        }
4083

4084
        // If we have no addresses, then we'll return nil instead of an
4085
        // empty slice.
4086
        if len(addresses) == 0 {
×
4087
                addresses = nil
×
4088
        }
×
4089

4090
        return addresses, nil
×
4091
}
4092

4093
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
4097
        nodeID int64, extraFields map[uint64][]byte) error {
×
4098

×
4099
        // Get any existing extra signed fields for the node.
×
4100
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
4101
        if err != nil {
×
4102
                return err
×
4103
        }
×
4104

4105
        // Make a lookup map of the existing field types so that we can use it
4106
        // to keep track of any fields we should delete.
4107
        m := make(map[uint64]bool)
×
4108
        for _, field := range existingFields {
×
4109
                m[uint64(field.Type)] = true
×
4110
        }
×
4111

4112
        // For all the new fields, we'll upsert them and remove them from the
4113
        // map of existing fields.
4114
        for tlvType, value := range extraFields {
×
4115
                err = db.UpsertNodeExtraType(
×
4116
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4117
                                NodeID: nodeID,
×
4118
                                Type:   int64(tlvType),
×
4119
                                Value:  value,
×
4120
                        },
×
4121
                )
×
4122
                if err != nil {
×
4123
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4124
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4125
                }
×
4126

4127
                // Remove the field from the map of existing fields if it was
4128
                // present.
4129
                delete(m, tlvType)
×
4130
        }
4131

4132
        // For all the fields that are left in the map of existing fields, we'll
4133
        // delete them as they are no longer present in the new set of fields.
4134
        for tlvType := range m {
×
4135
                err = db.DeleteExtraNodeType(
×
4136
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4137
                                NodeID: nodeID,
×
4138
                                Type:   int64(tlvType),
×
4139
                        },
×
4140
                )
×
4141
                if err != nil {
×
4142
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4143
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4144
                }
×
4145
        }
4146

4147
        return nil
×
4148
}
4149

4150
// srcNodeInfo holds the information about the source node of the graph.
4151
type srcNodeInfo struct {
4152
        // id is the DB level ID of the source node entry in the "nodes" table.
4153
        id int64
4154

4155
        // pub is the public key of the source node.
4156
        pub route.Vertex
4157
}
4158

4159
// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4162
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4163

×
4164
        s.srcNodeMu.Lock()
×
4165
        defer s.srcNodeMu.Unlock()
×
4166

×
4167
        // If we already have the source node ID and pub key cached, then
×
4168
        // return them.
×
4169
        if info, ok := s.srcNodes[version]; ok {
×
4170
                return info.id, info.pub, nil
×
4171
        }
×
4172

4173
        var pubKey route.Vertex
×
4174

×
4175
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4176
        if err != nil {
×
4177
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4178
                        err)
×
4179
        }
×
4180

4181
        if len(nodes) == 0 {
×
4182
                return 0, pubKey, ErrSourceNodeNotSet
×
4183
        } else if len(nodes) > 1 {
×
4184
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4185
                        "protocol %s found", version)
×
4186
        }
×
4187

4188
        copy(pubKey[:], nodes[0].PubKey)
×
4189

×
4190
        s.srcNodes[version] = &srcNodeInfo{
×
4191
                id:  nodes[0].NodeID,
×
4192
                pub: pubKey,
×
4193
        }
×
4194

×
4195
        return nodes[0].NodeID, pubKey, nil
×
4196
}
4197

4198
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
// stream. This then produces a map from TLV type to value. If the input is
// not a valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4202
        r := bytes.NewReader(data)
×
4203

×
4204
        tlvStream, err := tlv.NewStream()
×
4205
        if err != nil {
×
4206
                return nil, err
×
4207
        }
×
4208

4209
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4210
        // pass it into the P2P decoding variant.
4211
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4212
        if err != nil {
×
4213
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4214
        }
×
4215
        if len(parsedTypes) == 0 {
×
4216
                return nil, nil
×
4217
        }
×
4218

4219
        records := make(map[uint64][]byte)
×
4220
        for k, v := range parsedTypes {
×
4221
                records[uint64(k)] = v
×
4222
        }
×
4223

4224
        return records, nil
×
4225
}
4226
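// Illustrative sketch (not part of the upstream file): the intended round-trip
// for a gossip message's extra opaque data. The raw bytes are parsed into a
// TLV type->value map before being written to the *_extra_types tables, and
// lnwire.CustomRecords serialises that map back into the canonical byte
// stream when the record is rebuilt.
func exampleExtraOpaqueDataRoundTrip(extraOpaqueData []byte) ([]byte, error) {
        fields, err := marshalExtraOpaqueData(extraOpaqueData)
        if err != nil {
                return nil, err
        }

        // For a valid TLV stream, re-serialising the parsed fields is
        // expected to reproduce the original canonical encoding.
        return lnwire.CustomRecords(fields).Serialize()
}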

4227
// insertChannel inserts a new channel record into the database.
4228
func insertChannel(ctx context.Context, db SQLQueries,
4229
        edge *models.ChannelEdgeInfo) error {
×
4230

×
4231
        // Make sure that at least a "shell" entry for each node is present in
×
4232
        // the nodes table.
×
4233
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4234
        if err != nil {
×
4235
                return fmt.Errorf("unable to create shell node: %w", err)
×
4236
        }
×
4237

4238
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4239
        if err != nil {
×
4240
                return fmt.Errorf("unable to create shell node: %w", err)
×
4241
        }
×
4242

4243
        var capacity sql.NullInt64
×
4244
        if edge.Capacity != 0 {
×
4245
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4246
        }
×
4247

4248
        createParams := sqlc.CreateChannelParams{
×
4249
                Version:     int16(lnwire.GossipVersion1),
×
4250
                Scid:        channelIDToBytes(edge.ChannelID),
×
4251
                NodeID1:     node1DBID,
×
4252
                NodeID2:     node2DBID,
×
4253
                Outpoint:    edge.ChannelPoint.String(),
×
4254
                Capacity:    capacity,
×
4255
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4256
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4257
        }
×
4258

×
4259
        if edge.AuthProof != nil {
×
4260
                proof := edge.AuthProof
×
4261

×
4262
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4263
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4264
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4265
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4266
        }
×
4267

4268
        // Insert the new channel record.
4269
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4270
        if err != nil {
×
4271
                return err
×
4272
        }
×
4273

4274
        // Insert any channel features.
4275
        for feature := range edge.Features.Features() {
×
4276
                err = db.InsertChannelFeature(
×
4277
                        ctx, sqlc.InsertChannelFeatureParams{
×
4278
                                ChannelID:  dbChanID,
×
4279
                                FeatureBit: int32(feature),
×
4280
                        },
×
4281
                )
×
4282
                if err != nil {
×
4283
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4284
                                "feature(%v): %w", dbChanID, feature, err)
×
4285
                }
×
4286
        }
4287

4288
        // Finally, insert any extra TLV fields in the channel announcement.
4289
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4290
        if err != nil {
×
4291
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4292
                        err)
×
4293
        }
×
4294

4295
        for tlvType, value := range extra {
×
4296
                err := db.UpsertChannelExtraType(
×
4297
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4298
                                ChannelID: dbChanID,
×
4299
                                Type:      int64(tlvType),
×
4300
                                Value:     value,
×
4301
                        },
×
4302
                )
×
4303
                if err != nil {
×
4304
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4305
                                "extra signed field(%v): %w", edge.ChannelID,
×
4306
                                tlvType, err)
×
4307
                }
×
4308
        }
4309

4310
        return nil
×
4311
}
4312
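// Illustrative sketch (not part of the upstream file): the minimal set of
// fields a caller would populate on a ChannelEdgeInfo before handing it to
// insertChannel inside a write transaction. All concrete values below are
// hypothetical.
func exampleInsertChannel(ctx context.Context, db SQLQueries,
        node1, node2, btc1, btc2 route.Vertex, chanPoint wire.OutPoint) error {

        edge := &models.ChannelEdgeInfo{
                ChannelID:        123456789,
                NodeKey1Bytes:    node1,
                NodeKey2Bytes:    node2,
                BitcoinKey1Bytes: btc1,
                BitcoinKey2Bytes: btc2,
                ChannelPoint:     chanPoint,
                Capacity:         btcutil.Amount(1_000_000),
                Features:         lnwire.EmptyFeatureVector(),
        }

        // Without an AuthProof the channel is stored unauthenticated; the
        // signatures can be attached later once the proof is known.
        return insertChannel(ctx, db, edge)
}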

4313
// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4318
        pubKey route.Vertex) (int64, error) {
×
4319

×
4320
        dbNode, err := db.GetNodeByPubKey(
×
4321
                ctx, sqlc.GetNodeByPubKeyParams{
×
4322
                        PubKey:  pubKey[:],
×
4323
                        Version: int16(lnwire.GossipVersion1),
×
4324
                },
×
4325
        )
×
4326
        // The node exists. Return the ID.
×
4327
        if err == nil {
×
4328
                return dbNode.ID, nil
×
4329
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4330
                return 0, err
×
4331
        }
×
4332

4333
        // Otherwise, the node does not exist, so we create a shell entry for
4334
        // it.
4335
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4336
                Version: int16(lnwire.GossipVersion1),
×
4337
                PubKey:  pubKey[:],
×
4338
        })
×
4339
        if err != nil {
×
4340
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4341
        }
×
4342

4343
        return id, nil
×
4344
}
4345

4346
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
// the database. This includes deleting any existing types and then inserting
// the new types.
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4350
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4351

×
4352
        // Delete all existing extra signed fields for the channel policy.
×
4353
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4354
        if err != nil {
×
4355
                return fmt.Errorf("unable to delete "+
×
4356
                        "existing policy extra signed fields for policy %d: %w",
×
4357
                        chanPolicyID, err)
×
4358
        }
×
4359

4360
        // Insert all new extra signed fields for the channel policy.
4361
        for tlvType, value := range extraFields {
×
4362
                err = db.UpsertChanPolicyExtraType(
×
4363
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4364
                                ChannelPolicyID: chanPolicyID,
×
4365
                                Type:            int64(tlvType),
×
4366
                                Value:           value,
×
4367
                        },
×
4368
                )
×
4369
                if err != nil {
×
4370
                        return fmt.Errorf("unable to insert "+
×
4371
                                "channel_policy(%d) extra signed field(%v): %w",
×
4372
                                chanPolicyID, tlvType, err)
×
4373
                }
×
4374
        }
4375

4376
        return nil
×
4377
}
4378
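// Illustrative sketch (not part of the upstream file): how a policy's raw
// ExtraOpaqueData is expected to be converted into the TLV map consumed by
// upsertChanPolicyExtraSignedFields. The chanPolicyID is assumed to be the
// DB-level ID returned by the policy upsert.
func exampleUpsertPolicyExtras(ctx context.Context, db SQLQueries,
        chanPolicyID int64, extraOpaqueData []byte) error {

        extra, err := marshalExtraOpaqueData(extraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        return upsertChanPolicyExtraSignedFields(ctx, db, chanPolicyID, extra)
}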

4379
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
// provided dbChan row and also fetches any other information required to
// construct the edge info.
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4383
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4384
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4385

×
4386
        data, err := batchLoadChannelData(
×
4387
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4388
        )
×
4389
        if err != nil {
×
4390
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4391
                        err)
×
4392
        }
×
4393

4394
        return buildEdgeInfoWithBatchData(
×
4395
                cfg.ChainHash, dbChan, node1, node2, data,
×
4396
        )
×
4397
}
4398

4399
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4400
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4401
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4402
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4403

×
4404
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4405
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4406
                        dbChan.Version)
×
4407
        }
×
4408

4409
        // Use the pre-loaded features and extra types.
4410
        fv := lnwire.EmptyFeatureVector()
×
4411
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4412
                for _, bit := range features {
×
4413
                        fv.Set(lnwire.FeatureBit(bit))
×
4414
                }
×
4415
        }
4416

4417
        var extras map[uint64][]byte
×
4418
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4419
        if exists {
×
4420
                extras = channelExtras
×
4421
        } else {
×
4422
                extras = make(map[uint64][]byte)
×
4423
        }
×
4424

4425
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4426
        if err != nil {
×
4427
                return nil, err
×
4428
        }
×
4429

4430
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4431
        if err != nil {
×
4432
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4433
                        "fields: %w", err)
×
4434
        }
×
4435
        if recs == nil {
×
4436
                recs = make([]byte, 0)
×
4437
        }
×
4438

4439
        var btcKey1, btcKey2 route.Vertex
×
4440
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4441
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4442

×
4443
        channel := &models.ChannelEdgeInfo{
×
4444
                ChainHash:        chain,
×
4445
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4446
                NodeKey1Bytes:    node1,
×
4447
                NodeKey2Bytes:    node2,
×
4448
                BitcoinKey1Bytes: btcKey1,
×
4449
                BitcoinKey2Bytes: btcKey2,
×
4450
                ChannelPoint:     *op,
×
4451
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4452
                Features:         fv,
×
4453
                ExtraOpaqueData:  recs,
×
4454
        }
×
4455

×
4456
        // We always set all the signatures at the same time, so we can
×
4457
        // safely check if one signature is present to determine if we have the
×
4458
        // rest of the signatures for the auth proof.
×
4459
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4460
                channel.AuthProof = &models.ChannelAuthProof{
×
4461
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4462
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4463
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4464
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4465
                }
×
4466
        }
×
4467

4468
        return channel, nil
×
4469
}
4470

4471
// buildNodeVertices is a helper that converts raw node public keys
// into route.Vertex instances.
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4474
        route.Vertex, error) {
×
4475

×
4476
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4477
        if err != nil {
×
4478
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4479
                        "create vertex from node1 pubkey: %w", err)
×
4480
        }
×
4481

4482
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4483
        if err != nil {
×
4484
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4485
                        "create vertex from node2 pubkey: %w", err)
×
4486
        }
×
4487

4488
        return node1Vertex, node2Vertex, nil
×
4489
}
4490

4491
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy records and
// also retrieves all the extra info required to build the complete
// models.ChannelEdgePolicy types. It returns two policies, which may be nil
// if the corresponding sqlc.GraphChannelPolicy records are nil.
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4496
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4497
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4498
        *models.ChannelEdgePolicy, error) {
×
4499

×
4500
        if dbPol1 == nil && dbPol2 == nil {
×
4501
                return nil, nil, nil
×
4502
        }
×
4503

4504
        var policyIDs = make([]int64, 0, 2)
×
4505
        if dbPol1 != nil {
×
4506
                policyIDs = append(policyIDs, dbPol1.ID)
×
4507
        }
×
4508
        if dbPol2 != nil {
×
4509
                policyIDs = append(policyIDs, dbPol2.ID)
×
4510
        }
×
4511

4512
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4513
        if err != nil {
×
4514
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4515
                        "data: %w", err)
×
4516
        }
×
4517

4518
        pol1, err := buildChanPolicyWithBatchData(
×
4519
                dbPol1, channelID, node2, batchData,
×
4520
        )
×
4521
        if err != nil {
×
4522
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4523
        }
×
4524

4525
        pol2, err := buildChanPolicyWithBatchData(
×
4526
                dbPol2, channelID, node1, batchData,
×
4527
        )
×
4528
        if err != nil {
×
4529
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4530
        }
×
4531

4532
        return pol1, pol2, nil
×
4533
}
4534

4535
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
// then nil is returned for it.
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4539
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4540
        *models.CachedEdgePolicy, error) {
×
4541

×
4542
        var p1, p2 *models.CachedEdgePolicy
×
4543
        if dbPol1 != nil {
×
4544
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4545
                if err != nil {
×
4546
                        return nil, nil, err
×
4547
                }
×
4548

4549
                p1 = models.NewCachedPolicy(policy1)
×
4550
        }
4551
        if dbPol2 != nil {
×
4552
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4553
                if err != nil {
×
4554
                        return nil, nil, err
×
4555
                }
×
4556

4557
                p2 = models.NewCachedPolicy(policy2)
×
4558
        }
4559

4560
        return p1, p2, nil
×
4561
}
4562

4563
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
// provided sqlc.GraphChannelPolicy and other required information.
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4566
        extras map[uint64][]byte,
4567
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4568

×
4569
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4570
        if err != nil {
×
4571
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4572
                        "fields: %w", err)
×
4573
        }
×
4574

4575
        var inboundFee fn.Option[lnwire.Fee]
×
4576
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4577
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4578

×
4579
                inboundFee = fn.Some(lnwire.Fee{
×
4580
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4581
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4582
                })
×
4583
        }
×
4584

4585
        return &models.ChannelEdgePolicy{
×
4586
                SigBytes:  dbPolicy.Signature,
×
4587
                ChannelID: channelID,
×
4588
                LastUpdate: time.Unix(
×
4589
                        dbPolicy.LastUpdate.Int64, 0,
×
4590
                ),
×
4591
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4592
                        dbPolicy.MessageFlags,
×
4593
                ),
×
4594
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4595
                        dbPolicy.ChannelFlags,
×
4596
                ),
×
4597
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4598
                MinHTLC: lnwire.MilliSatoshi(
×
4599
                        dbPolicy.MinHtlcMsat,
×
4600
                ),
×
4601
                MaxHTLC: lnwire.MilliSatoshi(
×
4602
                        dbPolicy.MaxHtlcMsat.Int64,
×
4603
                ),
×
4604
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4605
                        dbPolicy.BaseFeeMsat,
×
4606
                ),
×
4607
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4608
                ToNode:                    toNode,
×
4609
                InboundFee:                inboundFee,
×
4610
                ExtraOpaqueData:           recs,
×
4611
        }, nil
×
4612
}
4613
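// Illustrative sketch (not part of the upstream file): how the optional
// inbound fee surfaces on the built policy. The sqlc row below is a
// hypothetical, mostly zero value with only the inbound fee columns set, so
// the resulting policy's InboundFee is fn.Some of those values.
func exampleInboundFeeFromPolicyRow(chanID uint64,
        toNode route.Vertex) (fn.Option[lnwire.Fee], error) {

        row := sqlc.GraphChannelPolicy{
                InboundBaseFeeMsat:      sqldb.SQLInt64(1000),
                InboundFeeRateMilliMsat: sqldb.SQLInt64(100),
        }

        pol, err := buildChanPolicy(row, chanID, nil, toNode)
        if err != nil {
                return fn.Option[lnwire.Fee]{}, err
        }

        return pol.InboundFee, nil
}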

4614
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row, which is expected to be a sqlc type that contains channel policy
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
//
4619
//nolint:ll,dupl,funlen
4620
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4621
        *sqlc.GraphChannelPolicy, error) {
×
4622

×
4623
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4624
        switch r := row.(type) {
×
4625
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4626
                if r.Policy1Timelock.Valid {
×
4627
                        policy1 = &sqlc.GraphChannelPolicy{
×
4628
                                Timelock:                r.Policy1Timelock.Int32,
×
4629
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4630
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4631
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4632
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4633
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4634
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4635
                                Disabled:                r.Policy1Disabled,
×
4636
                                MessageFlags:            r.Policy1MessageFlags,
×
4637
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4638
                        }
×
4639
                }
×
4640
                if r.Policy2Timelock.Valid {
×
4641
                        policy2 = &sqlc.GraphChannelPolicy{
×
4642
                                Timelock:                r.Policy2Timelock.Int32,
×
4643
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4644
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4645
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4646
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4647
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4648
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4649
                                Disabled:                r.Policy2Disabled,
×
4650
                                MessageFlags:            r.Policy2MessageFlags,
×
4651
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4652
                        }
×
4653
                }
×
4654

4655
                return policy1, policy2, nil
×
4656

4657
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4658
                if r.Policy1ID.Valid {
×
4659
                        policy1 = &sqlc.GraphChannelPolicy{
×
4660
                                ID:                      r.Policy1ID.Int64,
×
4661
                                Version:                 r.Policy1Version.Int16,
×
4662
                                ChannelID:               r.GraphChannel.ID,
×
4663
                                NodeID:                  r.Policy1NodeID.Int64,
×
4664
                                Timelock:                r.Policy1Timelock.Int32,
×
4665
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4666
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4667
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4668
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4669
                                LastUpdate:              r.Policy1LastUpdate,
×
4670
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4671
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4672
                                Disabled:                r.Policy1Disabled,
×
4673
                                MessageFlags:            r.Policy1MessageFlags,
×
4674
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4675
                                Signature:               r.Policy1Signature,
×
4676
                        }
×
4677
                }
×
4678
                if r.Policy2ID.Valid {
×
4679
                        policy2 = &sqlc.GraphChannelPolicy{
×
4680
                                ID:                      r.Policy2ID.Int64,
×
4681
                                Version:                 r.Policy2Version.Int16,
×
4682
                                ChannelID:               r.GraphChannel.ID,
×
4683
                                NodeID:                  r.Policy2NodeID.Int64,
×
4684
                                Timelock:                r.Policy2Timelock.Int32,
×
4685
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4686
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4687
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4688
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4689
                                LastUpdate:              r.Policy2LastUpdate,
×
4690
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4691
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4692
                                Disabled:                r.Policy2Disabled,
×
4693
                                MessageFlags:            r.Policy2MessageFlags,
×
4694
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4695
                                Signature:               r.Policy2Signature,
×
4696
                        }
×
4697
                }
×
4698

4699
                return policy1, policy2, nil
×
4700

4701
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4702
                if r.Policy1ID.Valid {
×
4703
                        policy1 = &sqlc.GraphChannelPolicy{
×
4704
                                ID:                      r.Policy1ID.Int64,
×
4705
                                Version:                 r.Policy1Version.Int16,
×
4706
                                ChannelID:               r.GraphChannel.ID,
×
4707
                                NodeID:                  r.Policy1NodeID.Int64,
×
4708
                                Timelock:                r.Policy1Timelock.Int32,
×
4709
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4710
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4711
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4712
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4713
                                LastUpdate:              r.Policy1LastUpdate,
×
4714
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4715
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4716
                                Disabled:                r.Policy1Disabled,
×
4717
                                MessageFlags:            r.Policy1MessageFlags,
×
4718
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4719
                                Signature:               r.Policy1Signature,
×
4720
                        }
×
4721
                }
×
4722
                if r.Policy2ID.Valid {
×
4723
                        policy2 = &sqlc.GraphChannelPolicy{
×
4724
                                ID:                      r.Policy2ID.Int64,
×
4725
                                Version:                 r.Policy2Version.Int16,
×
4726
                                ChannelID:               r.GraphChannel.ID,
×
4727
                                NodeID:                  r.Policy2NodeID.Int64,
×
4728
                                Timelock:                r.Policy2Timelock.Int32,
×
4729
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4730
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4731
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4732
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4733
                                LastUpdate:              r.Policy2LastUpdate,
×
4734
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4735
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4736
                                Disabled:                r.Policy2Disabled,
×
4737
                                MessageFlags:            r.Policy2MessageFlags,
×
4738
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4739
                                Signature:               r.Policy2Signature,
×
4740
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelBySCIDWithPoliciesRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsForNodeIDsRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsByNodeIDRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.ListChannelsWithPoliciesPaginatedRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	case sqlc.GetChannelsByIDsRow:
		if r.Policy1ID.Valid {
			policy1 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy1ID.Int64,
				Version:                 r.Policy1Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy1NodeID.Int64,
				Timelock:                r.Policy1Timelock.Int32,
				FeePpm:                  r.Policy1FeePpm.Int64,
				BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
				LastUpdate:              r.Policy1LastUpdate,
				InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
				Disabled:                r.Policy1Disabled,
				MessageFlags:            r.Policy1MessageFlags,
				ChannelFlags:            r.Policy1ChannelFlags,
				Signature:               r.Policy1Signature,
			}
		}
		if r.Policy2ID.Valid {
			policy2 = &sqlc.GraphChannelPolicy{
				ID:                      r.Policy2ID.Int64,
				Version:                 r.Policy2Version.Int16,
				ChannelID:               r.GraphChannel.ID,
				NodeID:                  r.Policy2NodeID.Int64,
				Timelock:                r.Policy2Timelock.Int32,
				FeePpm:                  r.Policy2FeePpm.Int64,
				BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
				MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
				MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
				LastUpdate:              r.Policy2LastUpdate,
				InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
				InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
				Disabled:                r.Policy2Disabled,
				MessageFlags:            r.Policy2MessageFlags,
				ChannelFlags:            r.Policy2ChannelFlags,
				Signature:               r.Policy2Signature,
			}
		}

		return policy1, policy2, nil

	default:
		return nil, nil, fmt.Errorf("unexpected row type in "+
			"extractChannelPolicies: %T", r)
	}
}

// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) []byte {
	var chanIDB [8]byte
	byteOrder.PutUint64(chanIDB[:], channelID)

	return chanIDB[:]
}
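
// exampleChannelIDRoundTrip is an illustrative sketch and not part of the
// original lnd source: it simply shows that channelIDToBytes produces the
// fixed 8-byte encoding of an SCID, and that decoding with the same
// package-level byteOrder recovers the original value.
func exampleChannelIDRoundTrip() {
	scid := uint64(123456789)

	// Encode the SCID into its 8-byte DB representation.
	encoded := channelIDToBytes(scid)

	// Decoding with the same byte order yields the original SCID again.
	if byteOrder.Uint64(encoded) != scid {
		panic("SCID round trip failed")
	}
}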

// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
	if len(addresses) == 0 {
		return nil, nil
	}

	result := make([]net.Addr, 0, len(addresses))
	for _, addr := range addresses {
		netAddr, err := parseAddress(addr.addrType, addr.address)
		if err != nil {
			return nil, fmt.Errorf("unable to parse address %s "+
				"of type %d: %w", addr.address, addr.addrType,
				err)
		}
		if netAddr != nil {
			result = append(result, netAddr)
		}
	}

	// If we have no valid addresses, return nil instead of an empty slice.
	if len(result) == 0 {
		return nil, nil
	}

	return result, nil
}
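
// exampleBuildNodeAddresses is an illustrative sketch and not part of the
// original lnd source: it shows how raw nodeAddress records, as they would be
// loaded from the DB, are converted into net.Addr values. The address strings
// below are made up for demonstration purposes.
func exampleBuildNodeAddresses() ([]net.Addr, error) {
	raw := []nodeAddress{
		{
			addrType: addressTypeIPv4,
			position: 0,
			address:  "203.0.113.7:9735",
		},
		{
			addrType: addressTypeIPv6,
			position: 1,
			address:  "[2001:db8::1]:9735",
		},
	}

	// Each record is parsed according to its stored address type. A
	// record that fails to parse results in an error rather than being
	// silently dropped.
	return buildNodeAddresses(raw)
}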
5050

5051
// parseAddress parses the given address string based on the address type
5052
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
5053
// and opaque addresses.
5054
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
5055
        switch addrType {
×
5056
        case addressTypeIPv4:
×
5057
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
5058
                if err != nil {
×
5059
                        return nil, err
×
5060
                }
×
5061

5062
                tcp.IP = tcp.IP.To4()
×
5063

×
5064
                return tcp, nil
×
5065

5066
        case addressTypeIPv6:
×
5067
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
5068
                if err != nil {
×
5069
                        return nil, err
×
5070
                }
×
5071

5072
                return tcp, nil
×
5073

5074
        case addressTypeTorV3, addressTypeTorV2:
×
5075
                service, portStr, err := net.SplitHostPort(address)
×
5076
                if err != nil {
×
5077
                        return nil, fmt.Errorf("unable to split tor "+
×
5078
                                "address: %v", address)
×
5079
                }
×
5080

5081
                port, err := strconv.Atoi(portStr)
×
5082
                if err != nil {
×
5083
                        return nil, err
×
5084
                }
×
5085

5086
                return &tor.OnionAddr{
×
5087
                        OnionService: service,
×
5088
                        Port:         port,
×
5089
                }, nil
×
5090

5091
        case addressTypeDNS:
×
5092
                hostname, portStr, err := net.SplitHostPort(address)
×
5093
                if err != nil {
×
5094
                        return nil, fmt.Errorf("unable to split DNS "+
×
5095
                                "address: %v", address)
×
5096
                }
×
5097

5098
                port, err := strconv.Atoi(portStr)
×
5099
                if err != nil {
×
5100
                        return nil, err
×
5101
                }
×
5102

5103
                return &lnwire.DNSAddress{
×
5104
                        Hostname: hostname,
×
5105
                        Port:     uint16(port),
×
5106
                }, nil
×
5107

5108
        case addressTypeOpaque:
×
5109
                opaque, err := hex.DecodeString(address)
×
5110
                if err != nil {
×
5111
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5112
                                "address: %v", address)
×
5113
                }
×
5114

5115
                return &lnwire.OpaqueAddrs{
×
5116
                        Payload: opaque,
×
5117
                }, nil
×
5118

5119
        default:
×
5120
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5121
        }
5122
}
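
// exampleParseAddress is an illustrative sketch and not part of the original
// lnd source: it shows how parseAddress maps the stored address type onto a
// concrete net.Addr implementation. The onion service string and the hex
// payload are made-up sample values.
func exampleParseAddress() error {
	// Tor addresses are split into an onion service and a port and are
	// returned as a *tor.OnionAddr.
	onion, err := parseAddress(
		addressTypeTorV3, "exampleonionservicev3address.onion:9735",
	)
	if err != nil {
		return err
	}

	// Opaque addresses are hex decoded into an *lnwire.OpaqueAddrs.
	opaque, err := parseAddress(addressTypeOpaque, "0102ff")
	if err != nil {
		return err
	}

	fmt.Printf("parsed addresses: %v and %v\n", onion, opaque)

	return nil
}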

// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
	// features is a map from a DB node ID to the feature bits for that
	// node.
	features map[int64][]int

	// addresses is a map from a DB node ID to the node's addresses.
	addresses map[int64][]nodeAddress

	// extraFields is a map from a DB node ID to the extra signed fields
	// for that node.
	extraFields map[int64]map[uint64][]byte
}

// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses.
type nodeAddress struct {
	addrType dbAddressType
	position int32
	address  string
}

// batchLoadNodeData loads all related data for a batch of node IDs using the
// provided SQLQueries interface. It returns a batchNodeData instance
// containing the node features, addresses and extra signed fields.
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {

	// Batch load the node features.
	features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node "+
			"features: %w", err)
	}

	// Batch load the node addresses.
	addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node "+
			"addresses: %w", err)
	}

	// Batch load the node extra signed fields.
	extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node extra "+
			"signed fields: %w", err)
	}

	return &batchNodeData{
		features:    features,
		addresses:   addrs,
		extraFields: extraTypes,
	}, nil
}
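
// exampleBatchNodeDataLookup is an illustrative sketch and not part of the
// original lnd source: it shows how the batchNodeData returned by
// batchLoadNodeData might be consumed for a single node. Every map is keyed
// by the DB node ID, and a missing entry simply means the node has no rows of
// that kind.
func exampleBatchNodeDataLookup(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, nodeID int64) error {

	data, err := batchLoadNodeData(ctx, cfg, db, []int64{nodeID})
	if err != nil {
		return err
	}

	fmt.Printf("node %d has %d feature bits, %d addresses and %d extra "+
		"TLV records\n", nodeID, len(data.features[nodeID]),
		len(data.addresses[nodeID]), len(data.extraFields[nodeID]))

	return nil
}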

// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	nodeIDs []int64) (map[int64][]int, error) {

	features := make(map[int64][]int)

	return features, sqldb.ExecuteBatchQuery(
		ctx, cfg, nodeIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
			error) {

			return db.GetNodeFeaturesBatch(ctx, ids)
		},
		func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
			features[feature.NodeID] = append(
				features[feature.NodeID],
				int(feature.FeatureBit),
			)

			return nil
		},
	)
}

// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
// returns a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	nodeIDs []int64) (map[int64][]nodeAddress, error) {

	addrs := make(map[int64][]nodeAddress)

	return addrs, sqldb.ExecuteBatchQuery(
		ctx, cfg, nodeIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
			error) {

			return db.GetNodeAddressesBatch(ctx, ids)
		},
		func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
			addrs[addr.NodeID] = append(
				addrs[addr.NodeID], nodeAddress{
					addrType: dbAddressType(addr.Type),
					position: addr.Position,
					address:  addr.Address,
				},
			)

			return nil
		},
	)
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

	extraFields := make(map[int64]map[uint64][]byte)

	callback := func(ctx context.Context,
		field sqlc.GraphNodeExtraType) error {

		if extraFields[field.NodeID] == nil {
			extraFields[field.NodeID] = make(map[uint64][]byte)
		}
		extraFields[field.NodeID][uint64(field.Type)] = field.Value

		return nil
	}

	return extraFields, sqldb.ExecuteBatchQuery(
		ctx, cfg, nodeIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) (
			[]sqlc.GraphNodeExtraType, error) {

			return db.GetNodeExtraTypesBatch(ctx, ids)
		},
		callback,
	)
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
	channelID uint64, node1, node2 route.Vertex,
	batchData *batchChannelData) (*models.ChannelEdgePolicy,
	*models.ChannelEdgePolicy, error) {

	pol1, err := buildChanPolicyWithBatchData(
		dbPol1, channelID, node2, batchData,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
	}

	pol2, err := buildChanPolicyWithBatchData(
		dbPol2, channelID, node1, batchData,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
	}

	return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
	channelID uint64, toNode route.Vertex,
	batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

	if dbPol == nil {
		return nil, nil
	}

	var dbPol1Extras map[uint64][]byte
	if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
		dbPol1Extras = extras
	} else {
		dbPol1Extras = make(map[uint64][]byte)
	}

	return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
	// chanfeatures is a map from DB channel ID to a slice of feature bits.
	chanfeatures map[int64][]int

	// chanExtraTypes is a map from DB channel ID to a map of TLV type to
	// extra signed field bytes.
	chanExtraTypes map[int64]map[uint64][]byte

	// policyExtras is a map from DB channel policy ID to a map of TLV type
	// to extra signed field bytes.
	policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, channelIDs []int64,
	policyIDs []int64) (*batchChannelData, error) {

	batchData := &batchChannelData{
		chanfeatures:   make(map[int64][]int),
		chanExtraTypes: make(map[int64]map[uint64][]byte),
		policyExtras:   make(map[int64]map[uint64][]byte),
	}

	// Batch load channel features and extras.
	var err error
	if len(channelIDs) > 0 {
		batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
			ctx, cfg, db, channelIDs,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to batch load "+
				"channel features: %w", err)
		}

		batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
			ctx, cfg, db, channelIDs,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to batch load "+
				"channel extras: %w", err)
		}
	}

	if len(policyIDs) > 0 {
		policyExtras, err := batchLoadChannelPolicyExtrasHelper(
			ctx, cfg, db, policyIDs,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to batch load "+
				"policy extras: %w", err)
		}
		batchData.policyExtras = policyExtras
	}

	return batchData, nil
}
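
// exampleBatchChannelDataLookup is an illustrative sketch and not part of the
// original lnd source: it shows how the batchChannelData returned by
// batchLoadChannelData might be consumed. Channel features and extras are
// keyed by the DB channel ID, while policy extras are keyed by the DB policy
// ID.
func exampleBatchChannelDataLookup(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries, chanID, policyID int64) error {

	data, err := batchLoadChannelData(
		ctx, cfg, db, []int64{chanID}, []int64{policyID},
	)
	if err != nil {
		return err
	}

	fmt.Printf("channel %d: %d feature bits and %d extra TLV records; "+
		"policy %d: %d extra TLV records\n", chanID,
		len(data.chanfeatures[chanID]),
		len(data.chanExtraTypes[chanID]), policyID,
		len(data.policyExtras[policyID]))

	return nil
}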

// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	channelIDs []int64) (map[int64][]int, error) {

	features := make(map[int64][]int)

	return features, sqldb.ExecuteBatchQuery(
		ctx, cfg, channelIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context,
			ids []int64) ([]sqlc.GraphChannelFeature, error) {

			return db.GetChannelFeaturesBatch(ctx, ids)
		},
		func(ctx context.Context,
			feature sqlc.GraphChannelFeature) error {

			features[feature.ChannelID] = append(
				features[feature.ChannelID],
				int(feature.FeatureBit),
			)

			return nil
		},
	)
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	channelIDs []int64) (map[int64]map[uint64][]byte, error) {

	extras := make(map[int64]map[uint64][]byte)

	cb := func(ctx context.Context,
		extra sqlc.GraphChannelExtraType) error {

		if extras[extra.ChannelID] == nil {
			extras[extra.ChannelID] = make(map[uint64][]byte)
		}
		extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

		return nil
	}

	return extras, sqldb.ExecuteBatchQuery(
		ctx, cfg, channelIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context,
			ids []int64) ([]sqlc.GraphChannelExtraType, error) {

			return db.GetChannelExtrasBatch(ctx, ids)
		}, cb,
	)
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
// a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
	cfg *sqldb.QueryConfig, db SQLQueries,
	policyIDs []int64) (map[int64]map[uint64][]byte, error) {

	extras := make(map[int64]map[uint64][]byte)

	return extras, sqldb.ExecuteBatchQuery(
		ctx, cfg, policyIDs,
		func(id int64) int64 {
			return id
		},
		func(ctx context.Context, ids []int64) (
			[]sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

			return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
		},
		func(ctx context.Context,
			row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

			if extras[row.PolicyID] == nil {
				extras[row.PolicyID] = make(map[uint64][]byte)
			}
			extras[row.PolicyID][uint64(row.Type)] = row.Value

			return nil
		},
	)
}

5481
// forEachNodePaginated executes a paginated query to process each node in the
5482
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5483
// and applies the provided processNode function to each node.
5484
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5485
        db SQLQueries, protocol lnwire.GossipVersion,
5486
        processNode func(context.Context, int64,
5487
                *models.Node) error) error {
×
5488

×
5489
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5490
                limit int32) ([]sqlc.GraphNode, error) {
×
5491

×
5492
                return db.ListNodesPaginated(
×
5493
                        ctx, sqlc.ListNodesPaginatedParams{
×
5494
                                Version: int16(protocol),
×
5495
                                ID:      lastID,
×
5496
                                Limit:   limit,
×
5497
                        },
×
5498
                )
×
5499
        }
×
5500

5501
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5502
                return node.ID
×
5503
        }
×
5504

5505
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5506
                return node.ID, nil
×
5507
        }
×
5508

5509
        batchQueryFunc := func(ctx context.Context,
×
5510
                nodeIDs []int64) (*batchNodeData, error) {
×
5511

×
5512
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5513
        }
×
5514

5515
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5516
                batchData *batchNodeData) error {
×
5517

×
5518
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5519
                if err != nil {
×
5520
                        return fmt.Errorf("unable to build "+
×
5521
                                "node(id=%d): %w", dbNode.ID, err)
×
5522
                }
×
5523

5524
                return processNode(ctx, dbNode.ID, node)
×
5525
        }
5526

5527
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5528
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5529
                collectFunc, batchQueryFunc, processItem,
×
5530
        )
×
5531
}
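
// exampleForEachNodePaginated is an illustrative sketch and not part of the
// original lnd source: it shows how forEachNodePaginated might be driven. The
// callback receives the DB node ID alongside the fully built models.Node for
// every gossip v1 node in the graph.
func exampleForEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
	db SQLQueries) error {

	return forEachNodePaginated(
		ctx, cfg, db, lnwire.GossipVersion1,
		func(_ context.Context, nodeID int64,
			node *models.Node) error {

			fmt.Printf("node %d: %x\n", nodeID, node.PubKeyBytes)

			return nil
		},
	)
}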

// forEachChannelWithPolicies executes a paginated query to process each
// channel in the graph along with its policies.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
	cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
		*models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error) error {

	type channelBatchIDs struct {
		channelID int64
		policyIDs []int64
	}

	pageQueryFunc := func(ctx context.Context, lastID int64,
		limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
		error) {

		return db.ListChannelsWithPoliciesPaginated(
			ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
				Version: int16(lnwire.GossipVersion1),
				ID:      lastID,
				Limit:   limit,
			},
		)
	}

	extractPageCursor := func(
		row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

		return row.GraphChannel.ID
	}

	collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
		channelBatchIDs, error) {

		ids := channelBatchIDs{
			channelID: row.GraphChannel.ID,
		}

		// Extract policy IDs from the row.
		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return ids, err
		}

		if dbPol1 != nil {
			ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
		}
		if dbPol2 != nil {
			ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
		}

		return ids, nil
	}

	batchDataFunc := func(ctx context.Context,
		allIDs []channelBatchIDs) (*batchChannelData, error) {

		// Separate channel IDs from policy IDs.
		var (
			channelIDs = make([]int64, len(allIDs))
			policyIDs  = make([]int64, 0, len(allIDs)*2)
		)

		for i, ids := range allIDs {
			channelIDs[i] = ids.channelID
			policyIDs = append(policyIDs, ids.policyIDs...)
		}

		return batchLoadChannelData(
			ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
		)
	}

	processItem := func(ctx context.Context,
		row sqlc.ListChannelsWithPoliciesPaginatedRow,
		batchData *batchChannelData) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.GraphChannel, node1, node2,
			batchData,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		p1, p2, err := buildChanPoliciesWithBatchData(
			dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
		)
		if err != nil {
			return err
		}

		return processChannel(edge, p1, p2)
	}

	return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
		ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
		collectFunc, batchDataFunc, processItem,
	)
}
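
// exampleForEachChannelWithPolicies is an illustrative sketch and not part of
// the original lnd source: it shows how forEachChannelWithPolicies might be
// driven. The callback fires once per channel with its edge info and
// whichever of the two directed policies are known (either may be nil).
func exampleForEachChannelWithPolicies(ctx context.Context, db SQLQueries,
	cfg *SQLStoreConfig) error {

	return forEachChannelWithPolicies(
		ctx, db, cfg, func(info *models.ChannelEdgeInfo,
			policy1, policy2 *models.ChannelEdgePolicy) error {

			fmt.Printf("channel %d: policy1 known=%v, policy2 "+
				"known=%v\n", info.ChannelID, policy1 != nil,
				policy2 != nil)

			return nil
		},
	)
}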

// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
	nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
	channelBatchData *batchChannelData, features *lnwire.FeatureVector,
	toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

	node1, node2, err := buildNodeVertices(
		channelRow.Node1Pubkey, channelRow.Node2Pubkey,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to build node vertices: %w", err)
	}

	edge, err := buildEdgeInfoWithBatchData(
		chain, channelRow.GraphChannel, node1, node2, channelBatchData,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to build channel info: %w", err)
	}

	dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
	if err != nil {
		return nil, fmt.Errorf("unable to extract channel policies: %w",
			err)
	}

	p1, p2, err := buildChanPoliciesWithBatchData(
		dbPol1, dbPol2, edge.ChannelID, node1, node2,
		channelBatchData,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to build channel policies: %w",
			err)
	}

	// Determine outgoing and incoming policy for this specific node.
	p1ToNode := channelRow.GraphChannel.NodeID2
	p2ToNode := channelRow.GraphChannel.NodeID1
	outPolicy, inPolicy := p1, p2
	if (p1 != nil && p1ToNode == nodeID) ||
		(p2 != nil && p2ToNode != nodeID) {

		outPolicy, inPolicy = p2, p1
	}

	// Build cached policy.
	var cachedInPolicy *models.CachedEdgePolicy
	if inPolicy != nil {
		cachedInPolicy = models.NewCachedPolicy(inPolicy)
		cachedInPolicy.ToNodePubKey = toNodeCallback
		cachedInPolicy.ToNodeFeatures = features
	}

	// Extract inbound fee.
	var inboundFee lnwire.Fee
	if outPolicy != nil {
		outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
			inboundFee = fee
		})
	}

	// Build directed channel.
	directedChannel := &DirectedChannel{
		ChannelID:    edge.ChannelID,
		IsNode1:      nodePub == edge.NodeKey1Bytes,
		OtherNode:    edge.NodeKey2Bytes,
		Capacity:     edge.Capacity,
		OutPolicySet: outPolicy != nil,
		InPolicy:     cachedInPolicy,
		InboundFee:   inboundFee,
	}

	if nodePub == edge.NodeKey2Bytes {
		directedChannel.OtherNode = edge.NodeKey1Bytes
	}

	return directedChannel, nil
}

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
	cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

	var (
		channelIDs = make([]int64, len(rows))
		policyIDs  = make([]int64, 0, len(rows)*2)
		nodeIDs    = make([]int64, 0, len(rows)*2)

		// nodeIDSet is used to ensure we only collect unique node IDs.
		nodeIDSet = make(map[int64]bool)

		// edges will hold the final channel edges built from the rows.
		edges = make([]ChannelEdge, 0, len(rows))
	)

	// Collect all IDs needed for batch loading.
	for i, row := range rows {
		channelIDs[i] = row.Channel().ID

		// Collect policy IDs.
		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return nil, fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}
		if dbPol1 != nil {
			policyIDs = append(policyIDs, dbPol1.ID)
		}
		if dbPol2 != nil {
			policyIDs = append(policyIDs, dbPol2.ID)
		}

		var (
			node1ID = row.Node1().ID
			node2ID = row.Node2().ID
		)

		// Collect unique node IDs.
		if !nodeIDSet[node1ID] {
			nodeIDs = append(nodeIDs, node1ID)
			nodeIDSet[node1ID] = true
		}

		if !nodeIDSet[node2ID] {
			nodeIDs = append(nodeIDs, node2ID)
			nodeIDSet[node2ID] = true
		}
	}

	// Batch the data for all the channels and policies.
	channelBatchData, err := batchLoadChannelData(
		ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
	)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load channel and "+
			"policy data: %w", err)
	}

	// Batch the data for all the nodes.
	nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node data: %w",
			err)
	}

	// Build all channel edges using batch data.
	for _, row := range rows {
		// Build nodes using batch data.
		node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
		if err != nil {
			return nil, fmt.Errorf("unable to build node1: %w", err)
		}

		node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
		if err != nil {
			return nil, fmt.Errorf("unable to build node2: %w", err)
		}

		// Build channel info using batch data.
		channel, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
			node2.PubKeyBytes, channelBatchData,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to build channel "+
				"info: %w", err)
		}

		// Extract and build policies using batch data.
		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return nil, fmt.Errorf("unable to extract channel "+
				"policies: %w", err)
		}

		p1, p2, err := buildChanPoliciesWithBatchData(
			dbPol1, dbPol2, channel.ChannelID,
			node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
		)
		if err != nil {
			return nil, fmt.Errorf("unable to build channel "+
				"policies: %w", err)
		}

		edges = append(edges, ChannelEdge{
			Info:    channel,
			Policy1: p1,
			Policy2: p2,
			Node1:   node1,
			Node2:   node2,
		})
	}

	return edges, nil
}

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
	cfg *SQLStoreConfig, db SQLQueries, rows []T) (
	[]*models.ChannelEdgeInfo, []int64, error) {

	if len(rows) == 0 {
		return nil, nil, nil
	}

	// Collect all the channel IDs needed for batch loading.
	channelIDs := make([]int64, len(rows))
	for i, row := range rows {
		channelIDs[i] = row.Channel().ID
	}

	// Batch load the channel data.
	channelBatchData, err := batchLoadChannelData(
		ctx, cfg.QueryCfg, db, channelIDs, nil,
	)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to batch load channel "+
			"data: %w", err)
	}

	// Build all channel edges using batch data.
	edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
	for _, row := range rows {
		node1, node2, err := buildNodeVertices(
			row.Node1Pub(), row.Node2Pub(),
		)
		if err != nil {
			return nil, nil, err
		}

		// Build channel info using batch data.
		info, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.Channel(), node1, node2,
			channelBatchData,
		)
		if err != nil {
			return nil, nil, err
		}

		edges = append(edges, info)
	}

	return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of
// marking a channel as a zombie in the database. It takes into account whether
// we are in strict zombie pruning mode, and adjusts the node public keys
// accordingly based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
	row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
	strictZombiePruning bool, scid uint64) error {

	nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

	if strictZombiePruning {
		var e1UpdateTime, e2UpdateTime *time.Time
		if row.Policy1LastUpdate.Valid {
			e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
			e1UpdateTime = &e1Time
		}
		if row.Policy2LastUpdate.Valid {
			e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
			e2UpdateTime = &e2Time
		}

		nodeKey1, nodeKey2 = makeZombiePubkeys(
			info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
			e2UpdateTime,
		)
	}

	return db.UpsertZombieChannel(
		ctx, sqlc.UpsertZombieChannelParams{
			Version:  int16(lnwire.GossipVersion1),
			Scid:     channelIDToBytes(scid),
			NodeKey1: nodeKey1[:],
			NodeKey2: nodeKey2[:],
		},
	)
}

// removePublicNodeCache takes in a list of public keys and removes the
// corresponding node entries from the cache if they exist.
//
// NOTE: This can safely be called without holding a lock since the lru is
// thread safe.
func (s *SQLStore) removePublicNodeCache(pubkeys ...[33]byte) {
	for _, pubkey := range pubkeys {
		s.publicNodeCache.Delete(pubkey)
	}
}
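
// exampleRemovePublicNodeCache is an illustrative sketch and not part of the
// original lnd source: it shows how the cache eviction helper above might be
// called after a set of node announcements changes. route.Vertex has an
// underlying [33]byte, so vertices can be passed to the variadic parameter
// directly.
func (s *SQLStore) exampleRemovePublicNodeCache(nodes ...route.Vertex) {
	for _, node := range nodes {
		s.removePublicNodeCache(node)
	}
}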