• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 20343164844

18 Dec 2025 04:04PM UTC coverage: 65.149% (-0.05%) from 65.195%
20343164844

Pull #10363

github

web-flow
Merge fcfc0ec01 into 91423ee51
Pull Request #10363: graphdb: add caching for isPublicNode query

13 of 38 new or added lines in 5 files covered. (34.21%)

126 existing lines in 26 files now uncovered.

137757 of 211450 relevant lines covered (65.15%)

20758.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightninglabs/neutrino/cache"
25
        "github.com/lightninglabs/neutrino/cache/lru"
26
        "github.com/lightningnetwork/lnd/aliasmgr"
27
        "github.com/lightningnetwork/lnd/batch"
28
        "github.com/lightningnetwork/lnd/fn/v2"
29
        "github.com/lightningnetwork/lnd/graph/db/models"
30
        "github.com/lightningnetwork/lnd/lnwire"
31
        "github.com/lightningnetwork/lnd/routing/route"
32
        "github.com/lightningnetwork/lnd/sqldb"
33
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
34
        "github.com/lightningnetwork/lnd/tlv"
35
        "github.com/lightningnetwork/lnd/tor"
36
)
37

38
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
39
// execute queries against the SQL graph tables.
40
//
41
//nolint:ll,interfacebloat
42
type SQLQueries interface {
43
        /*
44
                Node queries.
45
        */
46
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
47
        UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
48
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
49
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
50
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
51
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
52
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
53
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
54
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
55
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
56
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
57
        DeleteNode(ctx context.Context, id int64) error
58

59
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
60
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
61
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
62
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
63

64
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
65
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
66
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
67
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
68

69
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
70
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
71
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
72
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
73
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
74

75
        /*
76
                Source node queries.
77
        */
78
        AddSourceNode(ctx context.Context, nodeID int64) error
79
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
80

81
        /*
82
                Channel queries.
83
        */
84
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
85
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
86
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
87
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
88
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
89
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
90
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
91
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
92
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
93
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
94
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
95
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
96
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
97
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
98
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
99
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
100
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
101
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
102
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
103
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
104
        DeleteChannels(ctx context.Context, ids []int64) error
105

106
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
107
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
108
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
109
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
110

111
        /*
112
                Channel Policy table queries.
113
        */
114
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
115
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
116
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
117

118
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
119
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
120
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
121

122
        /*
123
                Zombie index queries.
124
        */
125
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
126
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
127
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
128
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
129
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
130
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
131

132
        /*
133
                Prune log table queries.
134
        */
135
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
136
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
137
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
138
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
139
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
140

141
        /*
142
                Closed SCID table queries.
143
        */
144
        InsertClosedChannel(ctx context.Context, scid []byte) error
145
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
146
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
147

148
        /*
149
                Migration specific queries.
150

151
                NOTE: these should not be used in code other than migrations.
152
                Once sqldbv2 is in place, these can be removed from this struct
153
                as then migrations will have their own dedicated queries
154
                structs.
155
        */
156
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
157
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
158
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
159
}
160

161
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
162
// database operations.
163
type BatchedSQLQueries interface {
164
        SQLQueries
165
        sqldb.BatchedTx[SQLQueries]
166
}
167

168
// SQLStore is an implementation of the V1Store interface that uses a SQL
169
// database as the backend.
170
type SQLStore struct {
171
        cfg *SQLStoreConfig
172
        db  BatchedSQLQueries
173

174
        // cacheMu guards all caches (rejectCache and chanCache). If
175
        // this mutex will be acquired at the same time as the DB mutex then
176
        // the cacheMu MUST be acquired first to prevent deadlock.
177
        cacheMu     sync.RWMutex
178
        rejectCache *rejectCache
179
        chanCache   *channelCache
180

181
        publicNodeCache *lru.Cache[[33]byte, *cachedPublicNode]
182

183
        chanScheduler batch.Scheduler[SQLQueries]
184
        nodeScheduler batch.Scheduler[SQLQueries]
185

186
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
187
        srcNodeMu sync.Mutex
188
}
189

190
// cachedPublicNode is the value type stored in the public node LRU cache. It
// carries no data of its own; only the presence of an entry under a given key
// is meaningful. It provides the Size method the lru cache requires of its
// values.
type cachedPublicNode struct{}

// Size reports the cost of one cache entry. Every entry costs exactly one
// unit, so the cache capacity bounds the number of entries rather than their
// memory footprint.
func (*cachedPublicNode) Size() (uint64, error) {
	const entryCost = 1

	return entryCost, nil
}
×
199

200
// A compile-time assertion to ensure that SQLStore implements the V1Store
201
// interface.
202
var _ V1Store = (*SQLStore)(nil)
203

204
// SQLStoreConfig holds the configuration for the SQLStore.
205
type SQLStoreConfig struct {
206
        // ChainHash is the genesis hash for the chain that all the gossip
207
        // messages in this store are aimed at.
208
        ChainHash chainhash.Hash
209

210
        // QueryConfig holds configuration values for SQL queries.
211
        QueryCfg *sqldb.QueryConfig
212
}
213

214
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
215
// storage backend.
216
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
217
        options ...StoreOptionModifier) (*SQLStore, error) {
×
218

×
219
        opts := DefaultOptions()
×
220
        for _, o := range options {
×
221
                o(opts)
×
222
        }
×
223

224
        if opts.NoMigration {
×
225
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
226
                        "supported for SQL stores")
×
227
        }
×
228

229
        s := &SQLStore{
×
230
                cfg:         cfg,
×
231
                db:          db,
×
232
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
233
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
NEW
234
                publicNodeCache: lru.NewCache[[33]byte, *cachedPublicNode](
×
NEW
235
                        uint64(opts.PublicNodeCacheSize),
×
NEW
236
                ),
×
NEW
237
                srcNodes: make(map[lnwire.GossipVersion]*srcNodeInfo),
×
238
        }
×
239

×
240
        s.chanScheduler = batch.NewTimeScheduler(
×
241
                db, &s.cacheMu, opts.BatchCommitInterval,
×
242
        )
×
243
        s.nodeScheduler = batch.NewTimeScheduler(
×
244
                db, nil, opts.BatchCommitInterval,
×
245
        )
×
246

×
247
        return s, nil
×
248
}
249

250
// AddNode adds a vertex/node to the graph database. If the node is not
251
// in the database from before, this will add a new, unconnected one to the
252
// graph. If it is present from before, this will update that node's
253
// information.
254
//
255
// NOTE: part of the V1Store interface.
256
func (s *SQLStore) AddNode(ctx context.Context,
257
        node *models.Node, opts ...batch.SchedulerOption) error {
×
258

×
259
        r := &batch.Request[SQLQueries]{
×
260
                Opts: batch.NewSchedulerOptions(opts...),
×
261
                Do: func(queries SQLQueries) error {
×
262
                        _, err := upsertNode(ctx, queries, node)
×
263

×
264
                        // It is possible that two of the same node
×
265
                        // announcements are both being processed in the same
×
266
                        // batch. This may case the UpsertNode conflict to
×
267
                        // be hit since we require at the db layer that the
×
268
                        // new last_update is greater than the existing
×
269
                        // last_update. We need to gracefully handle this here.
×
270
                        if errors.Is(err, sql.ErrNoRows) {
×
271
                                return nil
×
272
                        }
×
273

274
                        return err
×
275
                },
276
        }
277

278
        return s.nodeScheduler.Execute(ctx, r)
×
279
}
280

281
// FetchNode attempts to look up a target node by its identity public
282
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
283
// returned.
284
//
285
// NOTE: part of the V1Store interface.
286
func (s *SQLStore) FetchNode(ctx context.Context,
287
        pubKey route.Vertex) (*models.Node, error) {
×
288

×
289
        var node *models.Node
×
290
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
291
                var err error
×
292
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
293

×
294
                return err
×
295
        }, sqldb.NoOpReset)
×
296
        if err != nil {
×
297
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
298
        }
×
299

300
        return node, nil
×
301
}
302

303
// HasNode determines if the graph has a vertex identified by the
304
// target node identity public key. If the node exists in the database, a
305
// timestamp of when the data for the node was lasted updated is returned along
306
// with a true boolean. Otherwise, an empty time.Time is returned with a false
307
// boolean.
308
//
309
// NOTE: part of the V1Store interface.
310
func (s *SQLStore) HasNode(ctx context.Context,
311
        pubKey [33]byte) (time.Time, bool, error) {
×
312

×
313
        var (
×
314
                exists     bool
×
315
                lastUpdate time.Time
×
316
        )
×
317
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
318
                dbNode, err := db.GetNodeByPubKey(
×
319
                        ctx, sqlc.GetNodeByPubKeyParams{
×
320
                                Version: int16(lnwire.GossipVersion1),
×
321
                                PubKey:  pubKey[:],
×
322
                        },
×
323
                )
×
324
                if errors.Is(err, sql.ErrNoRows) {
×
325
                        return nil
×
326
                } else if err != nil {
×
327
                        return fmt.Errorf("unable to fetch node: %w", err)
×
328
                }
×
329

330
                exists = true
×
331

×
332
                if dbNode.LastUpdate.Valid {
×
333
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
334
                }
×
335

336
                return nil
×
337
        }, sqldb.NoOpReset)
338
        if err != nil {
×
339
                return time.Time{}, false,
×
340
                        fmt.Errorf("unable to fetch node: %w", err)
×
341
        }
×
342

343
        return lastUpdate, exists, nil
×
344
}
345

346
// AddrsForNode returns all known addresses for the target node public key
347
// that the graph DB is aware of. The returned boolean indicates if the
348
// given node is unknown to the graph DB or not.
349
//
350
// NOTE: part of the V1Store interface.
351
func (s *SQLStore) AddrsForNode(ctx context.Context,
352
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
353

×
354
        var (
×
355
                addresses []net.Addr
×
356
                known     bool
×
357
        )
×
358
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
359
                // First, check if the node exists and get its DB ID if it
×
360
                // does.
×
361
                dbID, err := db.GetNodeIDByPubKey(
×
362
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
363
                                Version: int16(lnwire.GossipVersion1),
×
364
                                PubKey:  nodePub.SerializeCompressed(),
×
365
                        },
×
366
                )
×
367
                if errors.Is(err, sql.ErrNoRows) {
×
368
                        return nil
×
369
                }
×
370

371
                known = true
×
372

×
373
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
374
                if err != nil {
×
375
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
376
                                err)
×
377
                }
×
378

379
                return nil
×
380
        }, sqldb.NoOpReset)
381
        if err != nil {
×
382
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
383
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
384
        }
×
385

386
        return known, addresses, nil
×
387
}
388

389
// DeleteNode starts a new database transaction to remove a vertex/node
390
// from the database according to the node's public key.
391
//
392
// NOTE: part of the V1Store interface.
393
func (s *SQLStore) DeleteNode(ctx context.Context,
394
        pubKey route.Vertex) error {
×
395

×
396
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
397
                res, err := db.DeleteNodeByPubKey(
×
398
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
399
                                Version: int16(lnwire.GossipVersion1),
×
400
                                PubKey:  pubKey[:],
×
401
                        },
×
402
                )
×
403
                if err != nil {
×
404
                        return err
×
405
                }
×
406

407
                rows, err := res.RowsAffected()
×
408
                if err != nil {
×
409
                        return err
×
410
                }
×
411

412
                if rows == 0 {
×
413
                        return ErrGraphNodeNotFound
×
414
                } else if rows > 1 {
×
415
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
416
                }
×
417

418
                return err
×
419
        }, sqldb.NoOpReset)
420
        if err != nil {
×
421
                return fmt.Errorf("unable to delete node: %w", err)
×
422
        }
×
423

424
        return nil
×
425
}
426

427
// FetchNodeFeatures returns the features of the given node. If no features are
428
// known for the node, an empty feature vector is returned.
429
//
430
// NOTE: this is part of the graphdb.NodeTraverser interface.
431
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
432
        *lnwire.FeatureVector, error) {
×
433

×
434
        ctx := context.TODO()
×
435

×
436
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
437
}
×
438

439
// DisabledChannelIDs returns the channel ids of disabled channels.
440
// A channel is disabled when two of the associated ChanelEdgePolicies
441
// have their disabled bit on.
442
//
443
// NOTE: part of the V1Store interface.
444
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
445
        var (
×
446
                ctx     = context.TODO()
×
447
                chanIDs []uint64
×
448
        )
×
449
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
450
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
451
                if err != nil {
×
452
                        return fmt.Errorf("unable to fetch disabled "+
×
453
                                "channels: %w", err)
×
454
                }
×
455

456
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
457

×
458
                return nil
×
459
        }, sqldb.NoOpReset)
460
        if err != nil {
×
461
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
462
                        err)
×
463
        }
×
464

465
        return chanIDs, nil
×
466
}
467

468
// LookupAlias attempts to return the alias as advertised by the target node.
469
//
470
// NOTE: part of the V1Store interface.
471
func (s *SQLStore) LookupAlias(ctx context.Context,
472
        pub *btcec.PublicKey) (string, error) {
×
473

×
474
        var alias string
×
475
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
476
                dbNode, err := db.GetNodeByPubKey(
×
477
                        ctx, sqlc.GetNodeByPubKeyParams{
×
478
                                Version: int16(lnwire.GossipVersion1),
×
479
                                PubKey:  pub.SerializeCompressed(),
×
480
                        },
×
481
                )
×
482
                if errors.Is(err, sql.ErrNoRows) {
×
483
                        return ErrNodeAliasNotFound
×
484
                } else if err != nil {
×
485
                        return fmt.Errorf("unable to fetch node: %w", err)
×
486
                }
×
487

488
                if !dbNode.Alias.Valid {
×
489
                        return ErrNodeAliasNotFound
×
490
                }
×
491

492
                alias = dbNode.Alias.String
×
493

×
494
                return nil
×
495
        }, sqldb.NoOpReset)
496
        if err != nil {
×
497
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
498
        }
×
499

500
        return alias, nil
×
501
}
502

503
// SourceNode returns the source node of the graph. The source node is treated
504
// as the center node within a star-graph. This method may be used to kick off
505
// a path finding algorithm in order to explore the reachability of another
506
// node based off the source node.
507
//
508
// NOTE: part of the V1Store interface.
509
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
510
        error) {
×
511

×
512
        var node *models.Node
×
513
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
514
                _, nodePub, err := s.getSourceNode(
×
515
                        ctx, db, lnwire.GossipVersion1,
×
516
                )
×
517
                if err != nil {
×
518
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
519
                                err)
×
520
                }
×
521

522
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
523

×
524
                return err
×
525
        }, sqldb.NoOpReset)
526
        if err != nil {
×
527
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
528
        }
×
529

530
        return node, nil
×
531
}
532

533
// SetSourceNode sets the source node within the graph database. The source
534
// node is to be used as the center of a star-graph within path finding
535
// algorithms.
536
//
537
// NOTE: part of the V1Store interface.
538
func (s *SQLStore) SetSourceNode(ctx context.Context,
539
        node *models.Node) error {
×
540

×
541
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
542
                // For the source node, we use a less strict upsert that allows
×
543
                // updates even when the timestamp hasn't changed. This handles
×
544
                // the race condition where multiple goroutines (e.g.,
×
545
                // setSelfNode, createNewHiddenService, RPC updates) read the
×
546
                // same old timestamp, independently increment it, and try to
×
547
                // write concurrently. We want all parameter changes to persist,
×
548
                // even if timestamps collide.
×
549
                id, err := upsertSourceNode(ctx, db, node)
×
550
                if err != nil {
×
551
                        return fmt.Errorf("unable to upsert source node: %w",
×
552
                                err)
×
553
                }
×
554

555
                // Make sure that if a source node for this version is already
556
                // set, then the ID is the same as the one we are about to set.
557
                dbSourceNodeID, _, err := s.getSourceNode(
×
558
                        ctx, db, lnwire.GossipVersion1,
×
559
                )
×
560
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
561
                        return fmt.Errorf("unable to fetch source node: %w",
×
562
                                err)
×
563
                } else if err == nil {
×
564
                        if dbSourceNodeID != id {
×
565
                                return fmt.Errorf("v1 source node already "+
×
566
                                        "set to a different node: %d vs %d",
×
567
                                        dbSourceNodeID, id)
×
568
                        }
×
569

570
                        return nil
×
571
                }
572

573
                return db.AddSourceNode(ctx, id)
×
574
        }, sqldb.NoOpReset)
575
}
576

577
// NodeUpdatesInHorizon returns all the known lightning node which have an
578
// update timestamp within the passed range. This method can be used by two
579
// nodes to quickly determine if they have the same set of up to date node
580
// announcements.
581
//
582
// NOTE: This is part of the V1Store interface.
583
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
584
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
585

×
586
        cfg := defaultIteratorConfig()
×
587
        for _, opt := range opts {
×
588
                opt(cfg)
×
589
        }
×
590

591
        return func(yield func(*models.Node, error) bool) {
×
592
                var (
×
593
                        ctx            = context.TODO()
×
594
                        lastUpdateTime sql.NullInt64
×
595
                        lastPubKey     = make([]byte, 33)
×
596
                        hasMore        = true
×
597
                )
×
598

×
599
                // Each iteration, we'll read a batch amount of nodes, yield
×
600
                // them, then decide is we have more or not.
×
601
                for hasMore {
×
602
                        var batch []*models.Node
×
603

×
604
                        //nolint:ll
×
605
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
606
                                //nolint:ll
×
607
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
608
                                        StartTime: sqldb.SQLInt64(
×
609
                                                startTime.Unix(),
×
610
                                        ),
×
611
                                        EndTime: sqldb.SQLInt64(
×
612
                                                endTime.Unix(),
×
613
                                        ),
×
614
                                        LastUpdate: lastUpdateTime,
×
615
                                        LastPubKey: lastPubKey,
×
616
                                        OnlyPublic: sql.NullBool{
×
617
                                                Bool:  cfg.iterPublicNodes,
×
618
                                                Valid: true,
×
619
                                        },
×
620
                                        MaxResults: sqldb.SQLInt32(
×
621
                                                cfg.nodeUpdateIterBatchSize,
×
622
                                        ),
×
623
                                }
×
624
                                rows, err := db.GetNodesByLastUpdateRange(
×
625
                                        ctx, params,
×
626
                                )
×
627
                                if err != nil {
×
628
                                        return err
×
629
                                }
×
630

631
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
632

×
633
                                err = forEachNodeInBatch(
×
634
                                        ctx, s.cfg.QueryCfg, db, rows,
×
635
                                        func(_ int64, node *models.Node) error {
×
636
                                                batch = append(batch, node)
×
637

×
638
                                                // Update pagination cursors
×
639
                                                // based on the last processed
×
640
                                                // node.
×
641
                                                lastUpdateTime = sql.NullInt64{
×
642
                                                        Int64: node.LastUpdate.
×
643
                                                                Unix(),
×
644
                                                        Valid: true,
×
645
                                                }
×
646
                                                lastPubKey = node.PubKeyBytes[:]
×
647

×
648
                                                return nil
×
649
                                        },
×
650
                                )
651
                                if err != nil {
×
652
                                        return fmt.Errorf("unable to build "+
×
653
                                                "nodes: %w", err)
×
654
                                }
×
655

656
                                return nil
×
657
                        }, func() {
×
658
                                batch = []*models.Node{}
×
659
                        })
×
660

661
                        if err != nil {
×
662
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
663
                                        "error: %v", err)
×
664

×
665
                                yield(&models.Node{}, err)
×
666

×
667
                                return
×
668
                        }
×
669

670
                        for _, node := range batch {
×
671
                                if !yield(node, nil) {
×
672
                                        return
×
673
                                }
×
674
                        }
675

676
                        // If the batch didn't yield anything, then we're done.
677
                        if len(batch) == 0 {
×
678
                                break
×
679
                        }
680
                }
681
        }
682
}
683

684
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
685
// undirected edge from the two target nodes are created. The information stored
686
// denotes the static attributes of the channel, such as the channelID, the keys
687
// involved in creation of the channel, and the set of features that the channel
688
// supports. The chanPoint and chanID are used to uniquely identify the edge
689
// globally within the database.
690
//
691
// NOTE: part of the V1Store interface.
692
func (s *SQLStore) AddChannelEdge(ctx context.Context,
693
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
694

×
695
        var alreadyExists bool
×
696
        r := &batch.Request[SQLQueries]{
×
697
                Opts: batch.NewSchedulerOptions(opts...),
×
698
                Reset: func() {
×
699
                        alreadyExists = false
×
700
                },
×
701
                Do: func(tx SQLQueries) error {
×
702
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
703

×
704
                        // Make sure that the channel doesn't already exist. We
×
705
                        // do this explicitly instead of relying on catching a
×
706
                        // unique constraint error because relying on SQL to
×
707
                        // throw that error would abort the entire batch of
×
708
                        // transactions.
×
709
                        _, err := tx.GetChannelBySCID(
×
710
                                ctx, sqlc.GetChannelBySCIDParams{
×
711
                                        Scid:    chanIDB,
×
712
                                        Version: int16(lnwire.GossipVersion1),
×
713
                                },
×
714
                        )
×
715
                        if err == nil {
×
716
                                alreadyExists = true
×
717
                                return nil
×
718
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
719
                                return fmt.Errorf("unable to fetch channel: %w",
×
720
                                        err)
×
721
                        }
×
722

723
                        return insertChannel(ctx, tx, edge)
×
724
                },
725
                OnCommit: func(err error) error {
×
726
                        switch {
×
727
                        case err != nil:
×
728
                                return err
×
729
                        case alreadyExists:
×
730
                                return ErrEdgeAlreadyExist
×
731
                        default:
×
732
                                s.rejectCache.remove(edge.ChannelID)
×
733
                                s.chanCache.remove(edge.ChannelID)
×
734
                                return nil
×
735
                        }
736
                },
737
        }
738

739
        return s.chanScheduler.Execute(ctx, r)
×
740
}
741

742
// HighestChanID returns the "highest" known channel ID in the channel graph.
743
// This represents the "newest" channel from the PoV of the chain. This method
744
// can be used by peers to quickly determine if their graphs are in sync.
745
//
746
// NOTE: This is part of the V1Store interface.
747
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
748
        var highestChanID uint64
×
749
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
750
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
751
                if errors.Is(err, sql.ErrNoRows) {
×
752
                        return nil
×
753
                } else if err != nil {
×
754
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
755
                                err)
×
756
                }
×
757

758
                highestChanID = byteOrder.Uint64(chanID)
×
759

×
760
                return nil
×
761
        }, sqldb.NoOpReset)
762
        if err != nil {
×
763
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
764
        }
×
765

766
        return highestChanID, nil
×
767
}
768

769
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
770
// within the database for the referenced channel. The `flags` attribute within
771
// the ChannelEdgePolicy determines which of the directed edges are being
772
// updated. If the flag is 1, then the first node's information is being
773
// updated, otherwise it's the second node's information. The node ordering is
774
// determined by the lexicographical ordering of the identity public keys of the
775
// nodes on either side of the channel.
776
//
777
// NOTE: part of the V1Store interface.
778
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
779
        edge *models.ChannelEdgePolicy,
780
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
781

×
782
        var (
×
783
                isUpdate1    bool
×
784
                edgeNotFound bool
×
785
                from, to     route.Vertex
×
786
        )
×
787

×
788
        r := &batch.Request[SQLQueries]{
×
789
                Opts: batch.NewSchedulerOptions(opts...),
×
790
                Reset: func() {
×
791
                        isUpdate1 = false
×
792
                        edgeNotFound = false
×
793
                },
×
794
                Do: func(tx SQLQueries) error {
×
795
                        var err error
×
796
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
797
                                ctx, tx, edge,
×
798
                        )
×
799
                        // It is possible that two of the same policy
×
800
                        // announcements are both being processed in the same
×
801
                        // batch. This may case the UpsertEdgePolicy conflict to
×
802
                        // be hit since we require at the db layer that the
×
803
                        // new last_update is greater than the existing
×
804
                        // last_update. We need to gracefully handle this here.
×
805
                        if errors.Is(err, sql.ErrNoRows) {
×
806
                                return nil
×
807
                        } else if err != nil {
×
808
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
809
                        }
×
810

811
                        // Silence ErrEdgeNotFound so that the batch can
812
                        // succeed, but propagate the error via local state.
813
                        if errors.Is(err, ErrEdgeNotFound) {
×
814
                                edgeNotFound = true
×
815
                                return nil
×
816
                        }
×
817

818
                        return err
×
819
                },
820
                OnCommit: func(err error) error {
×
821
                        switch {
×
822
                        case err != nil:
×
823
                                return err
×
824
                        case edgeNotFound:
×
825
                                return ErrEdgeNotFound
×
826
                        default:
×
827
                                s.updateEdgeCache(edge, isUpdate1)
×
828
                                return nil
×
829
                        }
830
                },
831
        }
832

833
        err := s.chanScheduler.Execute(ctx, r)
×
834

×
835
        return from, to, err
×
836
}
837

838
// updateEdgeCache updates our reject and channel caches with the new
839
// edge policy information.
840
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
841
        isUpdate1 bool) {
×
842

×
843
        // If an entry for this channel is found in reject cache, we'll modify
×
844
        // the entry with the updated timestamp for the direction that was just
×
845
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
846
        // during the next query for this edge.
×
847
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
848
                if isUpdate1 {
×
849
                        entry.upd1Time = e.LastUpdate.Unix()
×
850
                } else {
×
851
                        entry.upd2Time = e.LastUpdate.Unix()
×
852
                }
×
853
                s.rejectCache.insert(e.ChannelID, entry)
×
854
        }
855

856
        // If an entry for this channel is found in channel cache, we'll modify
857
        // the entry with the updated policy for the direction that was just
858
        // written. If the edge doesn't exist, we'll defer loading the info and
859
        // policies and lazily read from disk during the next query.
860
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
861
                if isUpdate1 {
×
862
                        channel.Policy1 = e
×
863
                } else {
×
864
                        channel.Policy2 = e
×
865
                }
×
866
                s.chanCache.insert(e.ChannelID, channel)
×
867
        }
868
}
869

870
// ForEachSourceNodeChannel iterates through all channels of the source node,
871
// executing the passed callback on each. The call-back is provided with the
872
// channel's outpoint, whether we have a policy for the channel and the channel
873
// peer's node information.
874
//
875
// NOTE: part of the V1Store interface.
876
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
877
        cb func(chanPoint wire.OutPoint, havePolicy bool,
878
                otherNode *models.Node) error, reset func()) error {
×
879

×
880
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
881
                nodeID, nodePub, err := s.getSourceNode(
×
882
                        ctx, db, lnwire.GossipVersion1,
×
883
                )
×
884
                if err != nil {
×
885
                        return fmt.Errorf("unable to fetch source node: %w",
×
886
                                err)
×
887
                }
×
888

889
                return forEachNodeChannel(
×
890
                        ctx, db, s.cfg, nodeID,
×
891
                        func(info *models.ChannelEdgeInfo,
×
892
                                outPolicy *models.ChannelEdgePolicy,
×
893
                                _ *models.ChannelEdgePolicy) error {
×
894

×
895
                                // Fetch the other node.
×
896
                                var (
×
897
                                        otherNodePub [33]byte
×
898
                                        node1        = info.NodeKey1Bytes
×
899
                                        node2        = info.NodeKey2Bytes
×
900
                                )
×
901
                                switch {
×
902
                                case bytes.Equal(node1[:], nodePub[:]):
×
903
                                        otherNodePub = node2
×
904
                                case bytes.Equal(node2[:], nodePub[:]):
×
905
                                        otherNodePub = node1
×
906
                                default:
×
907
                                        return fmt.Errorf("node not " +
×
908
                                                "participating in this channel")
×
909
                                }
910

911
                                _, otherNode, err := getNodeByPubKey(
×
912
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
913
                                )
×
914
                                if err != nil {
×
915
                                        return fmt.Errorf("unable to fetch "+
×
916
                                                "other node(%x): %w",
×
917
                                                otherNodePub, err)
×
918
                                }
×
919

920
                                return cb(
×
921
                                        info.ChannelPoint, outPolicy != nil,
×
922
                                        otherNode,
×
923
                                )
×
924
                        },
925
                )
926
        }, reset)
927
}
928

929
// ForEachNode iterates through all the stored vertices/nodes in the graph,
930
// executing the passed callback with each node encountered. If the callback
931
// returns an error, then the transaction is aborted and the iteration stops
932
// early.
933
//
934
// NOTE: part of the V1Store interface.
935
func (s *SQLStore) ForEachNode(ctx context.Context,
936
        cb func(node *models.Node) error, reset func()) error {
×
937

×
938
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
939
                return forEachNodePaginated(
×
940
                        ctx, s.cfg.QueryCfg, db,
×
941
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
942
                                node *models.Node) error {
×
943

×
944
                                return cb(node)
×
945
                        },
×
946
                )
947
        }, reset)
948
}
949

950
// ForEachNodeDirectedChannel iterates through all channels of a given node,
951
// executing the passed callback on the directed edge representing the channel
952
// and its incoming policy. If the callback returns an error, then the iteration
953
// is halted with the error propagated back up to the caller.
954
//
955
// Unknown policies are passed into the callback as nil values.
956
//
957
// NOTE: this is part of the graphdb.NodeTraverser interface.
958
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
959
        cb func(channel *DirectedChannel) error, reset func()) error {
×
960

×
961
        var ctx = context.TODO()
×
962

×
963
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
964
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
965
        }, reset)
×
966
}
967

968
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
969
// graph, executing the passed callback with each node encountered. If the
970
// callback returns an error, then the transaction is aborted and the iteration
971
// stops early.
972
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
973
        cb func(route.Vertex, *lnwire.FeatureVector) error,
974
        reset func()) error {
×
975

×
976
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
977
                return forEachNodeCacheable(
×
978
                        ctx, s.cfg.QueryCfg, db,
×
979
                        func(_ int64, nodePub route.Vertex,
×
980
                                features *lnwire.FeatureVector) error {
×
981

×
982
                                return cb(nodePub, features)
×
983
                        },
×
984
                )
985
        }, reset)
986
        if err != nil {
×
987
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
988
        }
×
989

990
        return nil
×
991
}
992

993
// ForEachNodeChannel iterates through all channels of the given node,
994
// executing the passed callback with an edge info structure and the policies
995
// of each end of the channel. The first edge policy is the outgoing edge *to*
996
// the connecting node, while the second is the incoming edge *from* the
997
// connecting node. If the callback returns an error, then the iteration is
998
// halted with the error propagated back up to the caller.
999
//
1000
// Unknown policies are passed into the callback as nil values.
1001
//
1002
// NOTE: part of the V1Store interface.
1003
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
1004
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1005
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1006

×
1007
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1008
                dbNode, err := db.GetNodeByPubKey(
×
1009
                        ctx, sqlc.GetNodeByPubKeyParams{
×
1010
                                Version: int16(lnwire.GossipVersion1),
×
1011
                                PubKey:  nodePub[:],
×
1012
                        },
×
1013
                )
×
1014
                if errors.Is(err, sql.ErrNoRows) {
×
1015
                        return nil
×
1016
                } else if err != nil {
×
1017
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1018
                }
×
1019

1020
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
1021
        }, reset)
1022
}
1023

1024
// extractMaxUpdateTime returns the maximum of the two policy update times.
1025
// This is used for pagination cursor tracking.
1026
func extractMaxUpdateTime(
1027
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1028

×
1029
        switch {
×
1030
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1031
                return max(row.Policy1LastUpdate.Int64,
×
1032
                        row.Policy2LastUpdate.Int64)
×
1033
        case row.Policy1LastUpdate.Valid:
×
1034
                return row.Policy1LastUpdate.Int64
×
1035
        case row.Policy2LastUpdate.Valid:
×
1036
                return row.Policy2LastUpdate.Int64
×
1037
        default:
×
1038
                return 0
×
1039
        }
1040
}
1041

1042
// buildChannelFromRow constructs a ChannelEdge from a database row.
1043
// This includes building the nodes, channel info, and policies.
1044
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1045
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1046

×
1047
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1048
        if err != nil {
×
1049
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1050
                        err)
×
1051
        }
×
1052

1053
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1054
        if err != nil {
×
1055
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1056
                        err)
×
1057
        }
×
1058

1059
        channel, err := getAndBuildEdgeInfo(
×
1060
                ctx, s.cfg, db,
×
1061
                row.GraphChannel, node1.PubKeyBytes,
×
1062
                node2.PubKeyBytes,
×
1063
        )
×
1064
        if err != nil {
×
1065
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1066
                        "channel info: %w", err)
×
1067
        }
×
1068

1069
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1070
        if err != nil {
×
1071
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1072
                        "channel policies: %w", err)
×
1073
        }
×
1074

1075
        p1, p2, err := getAndBuildChanPolicies(
×
1076
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1077
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1078
        )
×
1079
        if err != nil {
×
1080
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1081
                        "channel policies: %w", err)
×
1082
        }
×
1083

1084
        return ChannelEdge{
×
1085
                Info:    channel,
×
1086
                Policy1: p1,
×
1087
                Policy2: p2,
×
1088
                Node1:   node1,
×
1089
                Node2:   node2,
×
1090
        }, nil
×
1091
}
1092

1093
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1094
// This method acquires the cache lock only once for the entire batch.
1095
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1096
        if len(edgesToCache) == 0 {
×
1097
                return
×
1098
        }
×
1099

1100
        s.cacheMu.Lock()
×
1101
        defer s.cacheMu.Unlock()
×
1102

×
1103
        for chanID, edge := range edgesToCache {
×
1104
                s.chanCache.insert(chanID, edge)
×
1105
        }
×
1106
}
1107

1108
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1109
// one edge that has an update timestamp within the specified horizon.
1110
//
1111
// Iterator Lifecycle:
1112
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1113
// 2. Query batch of channels with policies in time range
1114
// 3. For each channel: check if seen, check cache, or build from DB
1115
// 4. Yield channels to caller
1116
// 5. Update cache after successful batch
1117
// 6. Repeat with updated pagination cursor until no more results
1118
//
1119
// NOTE: This is part of the V1Store interface.
1120
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1121
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1122

×
1123
        // Apply options.
×
1124
        cfg := defaultIteratorConfig()
×
1125
        for _, opt := range opts {
×
1126
                opt(cfg)
×
1127
        }
×
1128

1129
        return func(yield func(ChannelEdge, error) bool) {
×
1130
                var (
×
1131
                        ctx            = context.TODO()
×
1132
                        edgesSeen      = make(map[uint64]struct{})
×
1133
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1134
                        hits           int
×
1135
                        total          int
×
1136
                        lastUpdateTime sql.NullInt64
×
1137
                        lastID         sql.NullInt64
×
1138
                        hasMore        = true
×
1139
                )
×
1140

×
1141
                // Each iteration, we'll read a batch amount of channel updates
×
1142
                // (consulting the cache along the way), yield them, then loop
×
1143
                // back to decide if we have any more updates to read out.
×
1144
                for hasMore {
×
1145
                        var batch []ChannelEdge
×
1146

×
1147
                        // Acquire read lock before starting transaction to
×
1148
                        // ensure consistent lock ordering (cacheMu -> DB) and
×
1149
                        // prevent deadlock with write operations.
×
1150
                        s.cacheMu.RLock()
×
1151

×
1152
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1153
                                func(db SQLQueries) error {
×
1154
                                        //nolint:ll
×
1155
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1156
                                                Version: int16(lnwire.GossipVersion1),
×
1157
                                                StartTime: sqldb.SQLInt64(
×
1158
                                                        startTime.Unix(),
×
1159
                                                ),
×
1160
                                                EndTime: sqldb.SQLInt64(
×
1161
                                                        endTime.Unix(),
×
1162
                                                ),
×
1163
                                                LastUpdateTime: lastUpdateTime,
×
1164
                                                LastID:         lastID,
×
1165
                                                MaxResults: sql.NullInt32{
×
1166
                                                        Int32: int32(
×
1167
                                                                cfg.chanUpdateIterBatchSize,
×
1168
                                                        ),
×
1169
                                                        Valid: true,
×
1170
                                                },
×
1171
                                        }
×
1172
                                        //nolint:ll
×
1173
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1174
                                                ctx, params,
×
1175
                                        )
×
1176
                                        if err != nil {
×
1177
                                                return err
×
1178
                                        }
×
1179

1180
                                        //nolint:ll
1181
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1182

×
1183
                                        //nolint:ll
×
1184
                                        for _, row := range rows {
×
1185
                                                lastUpdateTime = sql.NullInt64{
×
1186
                                                        Int64: extractMaxUpdateTime(row),
×
1187
                                                        Valid: true,
×
1188
                                                }
×
1189
                                                lastID = sql.NullInt64{
×
1190
                                                        Int64: row.GraphChannel.ID,
×
1191
                                                        Valid: true,
×
1192
                                                }
×
1193

×
1194
                                                // Skip if we've already
×
1195
                                                // processed this channel.
×
1196
                                                chanIDInt := byteOrder.Uint64(
×
1197
                                                        row.GraphChannel.Scid,
×
1198
                                                )
×
1199
                                                _, ok := edgesSeen[chanIDInt]
×
1200
                                                if ok {
×
1201
                                                        continue
×
1202
                                                }
1203

1204
                                                // Check cache (we already hold
1205
                                                // shared read lock).
1206
                                                channel, ok := s.chanCache.get(
×
1207
                                                        chanIDInt,
×
1208
                                                )
×
1209
                                                if ok {
×
1210
                                                        hits++
×
1211
                                                        total++
×
1212
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1213
                                                        batch = append(batch, channel)
×
1214

×
1215
                                                        continue
×
1216
                                                }
1217

1218
                                                chanEdge, err := s.buildChannelFromRow(
×
1219
                                                        ctx, db, row,
×
1220
                                                )
×
1221
                                                if err != nil {
×
1222
                                                        return err
×
1223
                                                }
×
1224

1225
                                                edgesSeen[chanIDInt] = struct{}{}
×
1226
                                                edgesToCache[chanIDInt] = chanEdge
×
1227

×
1228
                                                batch = append(batch, chanEdge)
×
1229

×
1230
                                                total++
×
1231
                                        }
1232

1233
                                        return nil
×
1234
                                }, func() {
×
1235
                                        batch = nil
×
1236
                                        edgesSeen = make(map[uint64]struct{})
×
1237
                                        edgesToCache = make(
×
1238
                                                map[uint64]ChannelEdge,
×
1239
                                        )
×
1240
                                })
×
1241

1242
                        // Release read lock after transaction completes.
1243
                        s.cacheMu.RUnlock()
×
1244

×
1245
                        if err != nil {
×
1246
                                log.Errorf("ChanUpdatesInHorizon "+
×
1247
                                        "batch error: %v", err)
×
1248

×
1249
                                yield(ChannelEdge{}, err)
×
1250

×
1251
                                return
×
1252
                        }
×
1253

1254
                        for _, edge := range batch {
×
1255
                                if !yield(edge, nil) {
×
1256
                                        return
×
1257
                                }
×
1258
                        }
1259

1260
                        // Update cache after successful batch yield, setting
1261
                        // the cache lock only once for the entire batch.
1262
                        s.updateChanCacheBatch(edgesToCache)
×
1263
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1264

×
1265
                        // If the batch didn't yield anything, then we're done.
×
1266
                        if len(batch) == 0 {
×
1267
                                break
×
1268
                        }
1269
                }
1270

1271
                if total > 0 {
×
1272
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1273
                                "%.2f (%d/%d)",
×
1274
                                float64(hits)*100/float64(total), hits, total)
×
1275
                } else {
×
1276
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1277
                                "in horizon (%s, %s)", startTime, endTime)
×
1278
                }
×
1279
        }
1280
}
1281

1282
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1283
// data to the call-back. If withAddrs is true, then the call-back will also be
1284
// provided with the addresses associated with the node. The address retrieval
1285
// result in an additional round-trip to the database, so it should only be used
1286
// if the addresses are actually needed.
1287
//
1288
// NOTE: part of the V1Store interface.
1289
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1290
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1291
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1292

×
1293
        type nodeCachedBatchData struct {
×
1294
                features      map[int64][]int
×
1295
                addrs         map[int64][]nodeAddress
×
1296
                chanBatchData *batchChannelData
×
1297
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1298
        }
×
1299

×
1300
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1301
                // pageQueryFunc is used to query the next page of nodes.
×
1302
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1303
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1304

×
1305
                        return db.ListNodeIDsAndPubKeys(
×
1306
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1307
                                        Version: int16(lnwire.GossipVersion1),
×
1308
                                        ID:      lastID,
×
1309
                                        Limit:   limit,
×
1310
                                },
×
1311
                        )
×
1312
                }
×
1313

1314
                // batchDataFunc is then used to batch load the data required
1315
                // for each page of nodes.
1316
                batchDataFunc := func(ctx context.Context,
×
1317
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1318

×
1319
                        // Batch load node features.
×
1320
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1321
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1322
                        )
×
1323
                        if err != nil {
×
1324
                                return nil, fmt.Errorf("unable to batch load "+
×
1325
                                        "node features: %w", err)
×
1326
                        }
×
1327

1328
                        // Maybe fetch the node's addresses if requested.
1329
                        var nodeAddrs map[int64][]nodeAddress
×
1330
                        if withAddrs {
×
1331
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1332
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1333
                                )
×
1334
                                if err != nil {
×
1335
                                        return nil, fmt.Errorf("unable to "+
×
1336
                                                "batch load node "+
×
1337
                                                "addresses: %w", err)
×
1338
                                }
×
1339
                        }
1340

1341
                        // Batch load ALL unique channels for ALL nodes in this
1342
                        // page.
1343
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1344
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1345
                                        Version:  int16(lnwire.GossipVersion1),
×
1346
                                        Node1Ids: nodeIDs,
×
1347
                                        Node2Ids: nodeIDs,
×
1348
                                },
×
1349
                        )
×
1350
                        if err != nil {
×
1351
                                return nil, fmt.Errorf("unable to batch "+
×
1352
                                        "fetch channels for nodes: %w", err)
×
1353
                        }
×
1354

1355
                        // Deduplicate channels and collect IDs.
1356
                        var (
×
1357
                                allChannelIDs []int64
×
1358
                                allPolicyIDs  []int64
×
1359
                        )
×
1360
                        uniqueChannels := make(
×
1361
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1362
                        )
×
1363

×
1364
                        for _, channel := range allChannels {
×
1365
                                channelID := channel.GraphChannel.ID
×
1366

×
1367
                                // Only process each unique channel once.
×
1368
                                _, exists := uniqueChannels[channelID]
×
1369
                                if exists {
×
1370
                                        continue
×
1371
                                }
1372

1373
                                uniqueChannels[channelID] = channel
×
1374
                                allChannelIDs = append(allChannelIDs, channelID)
×
1375

×
1376
                                if channel.Policy1ID.Valid {
×
1377
                                        allPolicyIDs = append(
×
1378
                                                allPolicyIDs,
×
1379
                                                channel.Policy1ID.Int64,
×
1380
                                        )
×
1381
                                }
×
1382
                                if channel.Policy2ID.Valid {
×
1383
                                        allPolicyIDs = append(
×
1384
                                                allPolicyIDs,
×
1385
                                                channel.Policy2ID.Int64,
×
1386
                                        )
×
1387
                                }
×
1388
                        }
1389

1390
                        // Batch load channel data for all unique channels.
1391
                        channelBatchData, err := batchLoadChannelData(
×
1392
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1393
                                allPolicyIDs,
×
1394
                        )
×
1395
                        if err != nil {
×
1396
                                return nil, fmt.Errorf("unable to batch "+
×
1397
                                        "load channel data: %w", err)
×
1398
                        }
×
1399

1400
                        // Create map of node ID to channels that involve this
1401
                        // node.
1402
                        nodeIDSet := make(map[int64]bool)
×
1403
                        for _, nodeID := range nodeIDs {
×
1404
                                nodeIDSet[nodeID] = true
×
1405
                        }
×
1406

1407
                        nodeChannelMap := make(
×
1408
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1409
                        )
×
1410
                        for _, channel := range uniqueChannels {
×
1411
                                // Add channel to both nodes if they're in our
×
1412
                                // current page.
×
1413
                                node1 := channel.GraphChannel.NodeID1
×
1414
                                if nodeIDSet[node1] {
×
1415
                                        nodeChannelMap[node1] = append(
×
1416
                                                nodeChannelMap[node1], channel,
×
1417
                                        )
×
1418
                                }
×
1419
                                node2 := channel.GraphChannel.NodeID2
×
1420
                                if nodeIDSet[node2] {
×
1421
                                        nodeChannelMap[node2] = append(
×
1422
                                                nodeChannelMap[node2], channel,
×
1423
                                        )
×
1424
                                }
×
1425
                        }
1426

1427
                        return &nodeCachedBatchData{
×
1428
                                features:      nodeFeatures,
×
1429
                                addrs:         nodeAddrs,
×
1430
                                chanBatchData: channelBatchData,
×
1431
                                chanMap:       nodeChannelMap,
×
1432
                        }, nil
×
1433
                }
1434

1435
                // processItem is used to process each node in the current page.
1436
                processItem := func(ctx context.Context,
×
1437
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1438
                        batchData *nodeCachedBatchData) error {
×
1439

×
1440
                        // Build feature vector for this node.
×
1441
                        fv := lnwire.EmptyFeatureVector()
×
1442
                        features, exists := batchData.features[nodeData.ID]
×
1443
                        if exists {
×
1444
                                for _, bit := range features {
×
1445
                                        fv.Set(lnwire.FeatureBit(bit))
×
1446
                                }
×
1447
                        }
1448

1449
                        var nodePub route.Vertex
×
1450
                        copy(nodePub[:], nodeData.PubKey)
×
1451

×
1452
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1453

×
1454
                        toNodeCallback := func() route.Vertex {
×
1455
                                return nodePub
×
1456
                        }
×
1457

1458
                        // Build cached channels map for this node.
1459
                        channels := make(map[uint64]*DirectedChannel)
×
1460
                        for _, channelRow := range nodeChannels {
×
1461
                                directedChan, err := buildDirectedChannel(
×
1462
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1463
                                        channelRow, batchData.chanBatchData, fv,
×
1464
                                        toNodeCallback,
×
1465
                                )
×
1466
                                if err != nil {
×
1467
                                        return err
×
1468
                                }
×
1469

1470
                                channels[directedChan.ChannelID] = directedChan
×
1471
                        }
1472

1473
                        addrs, err := buildNodeAddresses(
×
1474
                                batchData.addrs[nodeData.ID],
×
1475
                        )
×
1476
                        if err != nil {
×
1477
                                return fmt.Errorf("unable to build node "+
×
1478
                                        "addresses: %w", err)
×
1479
                        }
×
1480

1481
                        return cb(ctx, nodePub, addrs, channels)
×
1482
                }
1483

1484
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1485
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1486
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1487
                                return node.ID
×
1488
                        },
×
1489
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1490
                                error) {
×
1491

×
1492
                                return node.ID, nil
×
1493
                        },
×
1494
                        batchDataFunc, processItem,
1495
                )
1496
        }, reset)
1497
}
1498

1499
// ForEachChannelCacheable iterates through all the channel edges stored
1500
// within the graph and invokes the passed callback for each edge. The
1501
// callback takes two edges as since this is a directed graph, both the
1502
// in/out edges are visited. If the callback returns an error, then the
1503
// transaction is aborted and the iteration stops early.
1504
//
1505
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1506
// pointer for that particular channel edge routing policy will be
1507
// passed into the callback.
1508
//
1509
// NOTE: this method is like ForEachChannel but fetches only the data
1510
// required for the graph cache.
1511
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1512
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1513
        reset func()) error {
×
1514

×
1515
        ctx := context.TODO()
×
1516

×
1517
        handleChannel := func(_ context.Context,
×
1518
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1519

×
1520
                node1, node2, err := buildNodeVertices(
×
1521
                        row.Node1Pubkey, row.Node2Pubkey,
×
1522
                )
×
1523
                if err != nil {
×
1524
                        return err
×
1525
                }
×
1526

1527
                edge := buildCacheableChannelInfo(
×
1528
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1529
                )
×
1530

×
1531
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1532
                if err != nil {
×
1533
                        return err
×
1534
                }
×
1535

1536
                pol1, pol2, err := buildCachedChanPolicies(
×
1537
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1538
                )
×
1539
                if err != nil {
×
1540
                        return err
×
1541
                }
×
1542

1543
                return cb(edge, pol1, pol2)
×
1544
        }
1545

1546
        extractCursor := func(
×
1547
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1548

×
1549
                return row.ID
×
1550
        }
×
1551

1552
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1553
                //nolint:ll
×
1554
                queryFunc := func(ctx context.Context, lastID int64,
×
1555
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1556
                        error) {
×
1557

×
1558
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1559
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1560
                                        Version: int16(lnwire.GossipVersion1),
×
1561
                                        ID:      lastID,
×
1562
                                        Limit:   limit,
×
1563
                                },
×
1564
                        )
×
1565
                }
×
1566

1567
                return sqldb.ExecutePaginatedQuery(
×
1568
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1569
                        extractCursor, handleChannel,
×
1570
                )
×
1571
        }, reset)
1572
}
1573

1574
// ForEachChannel iterates through all the channel edges stored within the
1575
// graph and invokes the passed callback for each edge. The callback takes two
1576
// edges as since this is a directed graph, both the in/out edges are visited.
1577
// If the callback returns an error, then the transaction is aborted and the
1578
// iteration stops early.
1579
//
1580
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1581
// for that particular channel edge routing policy will be passed into the
1582
// callback.
1583
//
1584
// NOTE: part of the V1Store interface.
1585
func (s *SQLStore) ForEachChannel(ctx context.Context,
1586
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1587
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1588

×
1589
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1590
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1591
        }, reset)
×
1592
}
1593

1594
// FilterChannelRange returns the channel ID's of all known channels which were
1595
// mined in a block height within the passed range. The channel IDs are grouped
1596
// by their common block height. This method can be used to quickly share with a
1597
// peer the set of channels we know of within a particular range to catch them
1598
// up after a period of time offline. If withTimestamps is true then the
1599
// timestamp info of the latest received channel update messages of the channel
1600
// will be included in the response.
1601
//
1602
// NOTE: This is part of the V1Store interface.
1603
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1604
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1605

×
1606
        var (
×
1607
                ctx       = context.TODO()
×
1608
                startSCID = &lnwire.ShortChannelID{
×
1609
                        BlockHeight: startHeight,
×
1610
                }
×
1611
                endSCID = lnwire.ShortChannelID{
×
1612
                        BlockHeight: endHeight,
×
1613
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1614
                        TxPosition:  math.MaxUint16,
×
1615
                }
×
1616
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1617
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1618
        )
×
1619

×
1620
        // 1) get all channels where channelID is between start and end chan ID.
×
1621
        // 2) skip if not public (ie, no channel_proof)
×
1622
        // 3) collect that channel.
×
1623
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1624
        //    and add those timestamps to the collected channel.
×
1625
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1626
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1627
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1628
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1629
                                StartScid: chanIDStart,
×
1630
                                EndScid:   chanIDEnd,
×
1631
                        },
×
1632
                )
×
1633
                if err != nil {
×
1634
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1635
                                err)
×
1636
                }
×
1637

1638
                for _, dbChan := range dbChans {
×
1639
                        cid := lnwire.NewShortChanIDFromInt(
×
1640
                                byteOrder.Uint64(dbChan.Scid),
×
1641
                        )
×
1642
                        chanInfo := NewChannelUpdateInfo(
×
1643
                                cid, time.Time{}, time.Time{},
×
1644
                        )
×
1645

×
1646
                        if !withTimestamps {
×
1647
                                channelsPerBlock[cid.BlockHeight] = append(
×
1648
                                        channelsPerBlock[cid.BlockHeight],
×
1649
                                        chanInfo,
×
1650
                                )
×
1651

×
1652
                                continue
×
1653
                        }
1654

1655
                        //nolint:ll
1656
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1657
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1658
                                        Version:   int16(lnwire.GossipVersion1),
×
1659
                                        ChannelID: dbChan.ID,
×
1660
                                        NodeID:    dbChan.NodeID1,
×
1661
                                },
×
1662
                        )
×
1663
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1664
                                return fmt.Errorf("unable to fetch node1 "+
×
1665
                                        "policy: %w", err)
×
1666
                        } else if err == nil {
×
1667
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1668
                                        node1Policy.LastUpdate.Int64, 0,
×
1669
                                )
×
1670
                        }
×
1671

1672
                        //nolint:ll
1673
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1674
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1675
                                        Version:   int16(lnwire.GossipVersion1),
×
1676
                                        ChannelID: dbChan.ID,
×
1677
                                        NodeID:    dbChan.NodeID2,
×
1678
                                },
×
1679
                        )
×
1680
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1681
                                return fmt.Errorf("unable to fetch node2 "+
×
1682
                                        "policy: %w", err)
×
1683
                        } else if err == nil {
×
1684
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1685
                                        node2Policy.LastUpdate.Int64, 0,
×
1686
                                )
×
1687
                        }
×
1688

1689
                        channelsPerBlock[cid.BlockHeight] = append(
×
1690
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1691
                        )
×
1692
                }
1693

1694
                return nil
×
1695
        }, func() {
×
1696
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1697
        })
×
1698
        if err != nil {
×
1699
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1700
        }
×
1701

1702
        if len(channelsPerBlock) == 0 {
×
1703
                return nil, nil
×
1704
        }
×
1705

1706
        // Return the channel ranges in ascending block height order.
1707
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1708
        slices.Sort(blocks)
×
1709

×
1710
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1711
                return BlockChannelRange{
×
1712
                        Height:   block,
×
1713
                        Channels: channelsPerBlock[block],
×
1714
                }
×
1715
        }), nil
×
1716
}
1717

1718
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1719
// zombie. This method is used on an ad-hoc basis, when channels need to be
1720
// marked as zombies outside the normal pruning cycle.
1721
//
1722
// NOTE: part of the V1Store interface.
1723
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1724
        pubKey1, pubKey2 [33]byte) error {
×
1725

×
1726
        ctx := context.TODO()
×
1727

×
1728
        s.cacheMu.Lock()
×
1729
        defer s.cacheMu.Unlock()
×
1730

×
1731
        chanIDB := channelIDToBytes(chanID)
×
1732

×
1733
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1734
                return db.UpsertZombieChannel(
×
1735
                        ctx, sqlc.UpsertZombieChannelParams{
×
1736
                                Version:  int16(lnwire.GossipVersion1),
×
1737
                                Scid:     chanIDB,
×
1738
                                NodeKey1: pubKey1[:],
×
1739
                                NodeKey2: pubKey2[:],
×
1740
                        },
×
1741
                )
×
1742
        }, sqldb.NoOpReset)
×
1743
        if err != nil {
×
1744
                return fmt.Errorf("unable to upsert zombie channel "+
×
1745
                        "(channel_id=%d): %w", chanID, err)
×
1746
        }
×
1747

1748
        s.rejectCache.remove(chanID)
×
1749
        s.chanCache.remove(chanID)
×
1750

×
1751
        return nil
×
1752
}
1753

1754
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1755
//
1756
// NOTE: part of the V1Store interface.
1757
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1758
        s.cacheMu.Lock()
×
1759
        defer s.cacheMu.Unlock()
×
1760

×
1761
        var (
×
1762
                ctx     = context.TODO()
×
1763
                chanIDB = channelIDToBytes(chanID)
×
1764
        )
×
1765

×
1766
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1767
                res, err := db.DeleteZombieChannel(
×
1768
                        ctx, sqlc.DeleteZombieChannelParams{
×
1769
                                Scid:    chanIDB,
×
1770
                                Version: int16(lnwire.GossipVersion1),
×
1771
                        },
×
1772
                )
×
1773
                if err != nil {
×
1774
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1775
                                err)
×
1776
                }
×
1777

1778
                rows, err := res.RowsAffected()
×
1779
                if err != nil {
×
1780
                        return err
×
1781
                }
×
1782

1783
                if rows == 0 {
×
1784
                        return ErrZombieEdgeNotFound
×
1785
                } else if rows > 1 {
×
1786
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1787
                                "expected 1", rows)
×
1788
                }
×
1789

1790
                return nil
×
1791
        }, sqldb.NoOpReset)
1792
        if err != nil {
×
1793
                return fmt.Errorf("unable to mark edge live "+
×
1794
                        "(channel_id=%d): %w", chanID, err)
×
1795
        }
×
1796

1797
        s.rejectCache.remove(chanID)
×
1798
        s.chanCache.remove(chanID)
×
1799

×
1800
        return err
×
1801
}
1802

1803
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1804
// zombie, then the two node public keys corresponding to this edge are also
1805
// returned.
1806
//
1807
// NOTE: part of the V1Store interface.
1808
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1809
        error) {
×
1810

×
1811
        var (
×
1812
                ctx              = context.TODO()
×
1813
                isZombie         bool
×
1814
                pubKey1, pubKey2 route.Vertex
×
1815
                chanIDB          = channelIDToBytes(chanID)
×
1816
        )
×
1817

×
1818
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1819
                zombie, err := db.GetZombieChannel(
×
1820
                        ctx, sqlc.GetZombieChannelParams{
×
1821
                                Scid:    chanIDB,
×
1822
                                Version: int16(lnwire.GossipVersion1),
×
1823
                        },
×
1824
                )
×
1825
                if errors.Is(err, sql.ErrNoRows) {
×
1826
                        return nil
×
1827
                }
×
1828
                if err != nil {
×
1829
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1830
                                err)
×
1831
                }
×
1832

1833
                copy(pubKey1[:], zombie.NodeKey1)
×
1834
                copy(pubKey2[:], zombie.NodeKey2)
×
1835
                isZombie = true
×
1836

×
1837
                return nil
×
1838
        }, sqldb.NoOpReset)
1839
        if err != nil {
×
1840
                return false, route.Vertex{}, route.Vertex{},
×
1841
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1842
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1843
        }
×
1844

1845
        return isZombie, pubKey1, pubKey2, nil
×
1846
}
1847

1848
// NumZombies returns the current number of zombie channels in the graph.
1849
//
1850
// NOTE: part of the V1Store interface.
1851
func (s *SQLStore) NumZombies() (uint64, error) {
×
1852
        var (
×
1853
                ctx        = context.TODO()
×
1854
                numZombies uint64
×
1855
        )
×
1856
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1857
                count, err := db.CountZombieChannels(
×
1858
                        ctx, int16(lnwire.GossipVersion1),
×
1859
                )
×
1860
                if err != nil {
×
1861
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1862
                                err)
×
1863
                }
×
1864

1865
                numZombies = uint64(count)
×
1866

×
1867
                return nil
×
1868
        }, sqldb.NoOpReset)
1869
        if err != nil {
×
1870
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1871
        }
×
1872

1873
        return numZombies, nil
×
1874
}
1875

1876
// DeleteChannelEdges removes edges with the given channel IDs from the
1877
// database and marks them as zombies. This ensures that we're unable to re-add
1878
// it to our database once again. If an edge does not exist within the
1879
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1880
// true, then when we mark these edges as zombies, we'll set up the keys such
1881
// that we require the node that failed to send the fresh update to be the one
1882
// that resurrects the channel from its zombie state. The markZombie bool
1883
// denotes whether to mark the channel as a zombie.
1884
//
1885
// NOTE: part of the V1Store interface.
1886
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1887
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1888

×
1889
        s.cacheMu.Lock()
×
1890
        defer s.cacheMu.Unlock()
×
1891

×
1892
        // Keep track of which channels we end up finding so that we can
×
1893
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1894
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1895
        for _, chanID := range chanIDs {
×
1896
                chanLookup[chanID] = struct{}{}
×
1897
        }
×
1898

1899
        var (
×
1900
                ctx   = context.TODO()
×
1901
                edges []*models.ChannelEdgeInfo
×
1902
        )
×
1903
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1904
                // First, collect all channel rows.
×
1905
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1906
                chanCallBack := func(ctx context.Context,
×
1907
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1908

×
1909
                        // Deleting the entry from the map indicates that we
×
1910
                        // have found the channel.
×
1911
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1912
                        delete(chanLookup, scid)
×
1913

×
1914
                        channelRows = append(channelRows, row)
×
1915

×
1916
                        return nil
×
1917
                }
×
1918

1919
                err := s.forEachChanWithPoliciesInSCIDList(
×
1920
                        ctx, db, chanCallBack, chanIDs,
×
1921
                )
×
1922
                if err != nil {
×
1923
                        return err
×
1924
                }
×
1925

1926
                if len(chanLookup) > 0 {
×
1927
                        return ErrEdgeNotFound
×
1928
                }
×
1929

1930
                if len(channelRows) == 0 {
×
1931
                        return nil
×
1932
                }
×
1933

1934
                // Batch build all channel edges.
1935
                var chanIDsToDelete []int64
×
1936
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1937
                        ctx, s.cfg, db, channelRows,
×
1938
                )
×
1939
                if err != nil {
×
1940
                        return err
×
1941
                }
×
1942

1943
                if markZombie {
×
1944
                        for i, row := range channelRows {
×
1945
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1946

×
1947
                                err := handleZombieMarking(
×
1948
                                        ctx, db, row, edges[i],
×
1949
                                        strictZombiePruning, scid,
×
1950
                                )
×
1951
                                if err != nil {
×
1952
                                        return fmt.Errorf("unable to mark "+
×
1953
                                                "channel as zombie: %w", err)
×
1954
                                }
×
1955
                        }
1956
                }
1957

1958
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1959
        }, func() {
×
1960
                edges = nil
×
1961

×
1962
                // Re-fill the lookup map.
×
1963
                for _, chanID := range chanIDs {
×
1964
                        chanLookup[chanID] = struct{}{}
×
1965
                }
×
1966
        })
1967
        if err != nil {
×
1968
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1969
                        err)
×
1970
        }
×
1971

1972
        for _, chanID := range chanIDs {
×
1973
                s.rejectCache.remove(chanID)
×
1974
                s.chanCache.remove(chanID)
×
1975
        }
×
1976

1977
        return edges, nil
×
1978
}
1979

1980
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1981
// channel identified by the channel ID. If the channel can't be found, then
1982
// ErrEdgeNotFound is returned. A struct which houses the general information
1983
// for the channel itself is returned as well as two structs that contain the
1984
// routing policies for the channel in either direction.
1985
//
1986
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1987
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1988
// the ChannelEdgeInfo will only include the public keys of each node.
1989
//
1990
// NOTE: part of the V1Store interface.
1991
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1992
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1993
        *models.ChannelEdgePolicy, error) {
×
1994

×
1995
        var (
×
1996
                ctx              = context.TODO()
×
1997
                edge             *models.ChannelEdgeInfo
×
1998
                policy1, policy2 *models.ChannelEdgePolicy
×
1999
                chanIDB          = channelIDToBytes(chanID)
×
2000
        )
×
2001
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2002
                row, err := db.GetChannelBySCIDWithPolicies(
×
2003
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2004
                                Scid:    chanIDB,
×
2005
                                Version: int16(lnwire.GossipVersion1),
×
2006
                        },
×
2007
                )
×
2008
                if errors.Is(err, sql.ErrNoRows) {
×
2009
                        // First check if this edge is perhaps in the zombie
×
2010
                        // index.
×
2011
                        zombie, err := db.GetZombieChannel(
×
2012
                                ctx, sqlc.GetZombieChannelParams{
×
2013
                                        Scid:    chanIDB,
×
2014
                                        Version: int16(lnwire.GossipVersion1),
×
2015
                                },
×
2016
                        )
×
2017
                        if errors.Is(err, sql.ErrNoRows) {
×
2018
                                return ErrEdgeNotFound
×
2019
                        } else if err != nil {
×
2020
                                return fmt.Errorf("unable to check if "+
×
2021
                                        "channel is zombie: %w", err)
×
2022
                        }
×
2023

2024
                        // At this point, we know the channel is a zombie, so
2025
                        // we'll return an error indicating this, and we will
2026
                        // populate the edge info with the public keys of each
2027
                        // party as this is the only information we have about
2028
                        // it.
2029
                        edge = &models.ChannelEdgeInfo{}
×
2030
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
2031
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
2032

×
2033
                        return ErrZombieEdge
×
2034
                } else if err != nil {
×
2035
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2036
                }
×
2037

2038
                node1, node2, err := buildNodeVertices(
×
2039
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2040
                )
×
2041
                if err != nil {
×
2042
                        return err
×
2043
                }
×
2044

2045
                edge, err = getAndBuildEdgeInfo(
×
2046
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2047
                )
×
2048
                if err != nil {
×
2049
                        return fmt.Errorf("unable to build channel info: %w",
×
2050
                                err)
×
2051
                }
×
2052

2053
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2054
                if err != nil {
×
2055
                        return fmt.Errorf("unable to extract channel "+
×
2056
                                "policies: %w", err)
×
2057
                }
×
2058

2059
                policy1, policy2, err = getAndBuildChanPolicies(
×
2060
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2061
                        node1, node2,
×
2062
                )
×
2063
                if err != nil {
×
2064
                        return fmt.Errorf("unable to build channel "+
×
2065
                                "policies: %w", err)
×
2066
                }
×
2067

2068
                return nil
×
2069
        }, sqldb.NoOpReset)
2070
        if err != nil {
×
2071
                // If we are returning the ErrZombieEdge, then we also need to
×
2072
                // return the edge info as the method comment indicates that
×
2073
                // this will be populated when the edge is a zombie.
×
2074
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2075
                        err)
×
2076
        }
×
2077

2078
        return edge, policy1, policy2, nil
×
2079
}
2080

2081
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2082
// the channel identified by the funding outpoint. If the channel can't be
2083
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2084
// information for the channel itself is returned as well as two structs that
2085
// contain the routing policies for the channel in either direction.
2086
//
2087
// NOTE: part of the V1Store interface.
2088
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2089
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2090
        *models.ChannelEdgePolicy, error) {
×
2091

×
2092
        var (
×
2093
                ctx              = context.TODO()
×
2094
                edge             *models.ChannelEdgeInfo
×
2095
                policy1, policy2 *models.ChannelEdgePolicy
×
2096
        )
×
2097
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2098
                row, err := db.GetChannelByOutpointWithPolicies(
×
2099
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2100
                                Outpoint: op.String(),
×
2101
                                Version:  int16(lnwire.GossipVersion1),
×
2102
                        },
×
2103
                )
×
2104
                if errors.Is(err, sql.ErrNoRows) {
×
2105
                        return ErrEdgeNotFound
×
2106
                } else if err != nil {
×
2107
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2108
                }
×
2109

2110
                node1, node2, err := buildNodeVertices(
×
2111
                        row.Node1Pubkey, row.Node2Pubkey,
×
2112
                )
×
2113
                if err != nil {
×
2114
                        return err
×
2115
                }
×
2116

2117
                edge, err = getAndBuildEdgeInfo(
×
2118
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2119
                )
×
2120
                if err != nil {
×
2121
                        return fmt.Errorf("unable to build channel info: %w",
×
2122
                                err)
×
2123
                }
×
2124

2125
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2126
                if err != nil {
×
2127
                        return fmt.Errorf("unable to extract channel "+
×
2128
                                "policies: %w", err)
×
2129
                }
×
2130

2131
                policy1, policy2, err = getAndBuildChanPolicies(
×
2132
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2133
                        node1, node2,
×
2134
                )
×
2135
                if err != nil {
×
2136
                        return fmt.Errorf("unable to build channel "+
×
2137
                                "policies: %w", err)
×
2138
                }
×
2139

2140
                return nil
×
2141
        }, sqldb.NoOpReset)
2142
        if err != nil {
×
2143
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2144
                        err)
×
2145
        }
×
2146

2147
        return edge, policy1, policy2, nil
×
2148
}
2149

2150
// HasChannelEdge returns true if the database knows of a channel edge with the
2151
// passed channel ID, and false otherwise. If an edge with that ID is found
2152
// within the graph, then two time stamps representing the last time the edge
2153
// was updated for both directed edges are returned along with the boolean. If
2154
// it is not found, then the zombie index is checked and its result is returned
2155
// as the second boolean.
2156
//
2157
// NOTE: part of the V1Store interface.
2158
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2159
        bool, error) {
×
2160

×
2161
        ctx := context.TODO()
×
2162

×
2163
        var (
×
2164
                exists          bool
×
2165
                isZombie        bool
×
2166
                node1LastUpdate time.Time
×
2167
                node2LastUpdate time.Time
×
2168
        )
×
2169

×
2170
        // We'll query the cache with the shared lock held to allow multiple
×
2171
        // readers to access values in the cache concurrently if they exist.
×
2172
        s.cacheMu.RLock()
×
2173
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2174
                s.cacheMu.RUnlock()
×
2175
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2176
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2177
                exists, isZombie = entry.flags.unpack()
×
2178

×
2179
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2180
        }
×
2181
        s.cacheMu.RUnlock()
×
2182

×
2183
        s.cacheMu.Lock()
×
2184
        defer s.cacheMu.Unlock()
×
2185

×
2186
        // The item was not found with the shared lock, so we'll acquire the
×
2187
        // exclusive lock and check the cache again in case another method added
×
2188
        // the entry to the cache while no lock was held.
×
2189
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2190
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2191
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2192
                exists, isZombie = entry.flags.unpack()
×
2193

×
2194
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2195
        }
×
2196

2197
        chanIDB := channelIDToBytes(chanID)
×
2198
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2199
                channel, err := db.GetChannelBySCID(
×
2200
                        ctx, sqlc.GetChannelBySCIDParams{
×
2201
                                Scid:    chanIDB,
×
2202
                                Version: int16(lnwire.GossipVersion1),
×
2203
                        },
×
2204
                )
×
2205
                if errors.Is(err, sql.ErrNoRows) {
×
2206
                        // Check if it is a zombie channel.
×
2207
                        isZombie, err = db.IsZombieChannel(
×
2208
                                ctx, sqlc.IsZombieChannelParams{
×
2209
                                        Scid:    chanIDB,
×
2210
                                        Version: int16(lnwire.GossipVersion1),
×
2211
                                },
×
2212
                        )
×
2213
                        if err != nil {
×
2214
                                return fmt.Errorf("could not check if channel "+
×
2215
                                        "is zombie: %w", err)
×
2216
                        }
×
2217

2218
                        return nil
×
2219
                } else if err != nil {
×
2220
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2221
                }
×
2222

2223
                exists = true
×
2224

×
2225
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2226
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2227
                                Version:   int16(lnwire.GossipVersion1),
×
2228
                                ChannelID: channel.ID,
×
2229
                                NodeID:    channel.NodeID1,
×
2230
                        },
×
2231
                )
×
2232
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2233
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2234
                                err)
×
2235
                } else if err == nil {
×
2236
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2237
                }
×
2238

2239
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2240
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2241
                                Version:   int16(lnwire.GossipVersion1),
×
2242
                                ChannelID: channel.ID,
×
2243
                                NodeID:    channel.NodeID2,
×
2244
                        },
×
2245
                )
×
2246
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2247
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2248
                                err)
×
2249
                } else if err == nil {
×
2250
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2251
                }
×
2252

2253
                return nil
×
2254
        }, sqldb.NoOpReset)
2255
        if err != nil {
×
2256
                return time.Time{}, time.Time{}, false, false,
×
2257
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2258
        }
×
2259

2260
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2261
                upd1Time: node1LastUpdate.Unix(),
×
2262
                upd2Time: node2LastUpdate.Unix(),
×
2263
                flags:    packRejectFlags(exists, isZombie),
×
2264
        })
×
2265

×
2266
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2267
}
2268

2269
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2270
// passed channel point (outpoint). If the passed channel doesn't exist within
2271
// the database, then ErrEdgeNotFound is returned.
2272
//
2273
// NOTE: part of the V1Store interface.
2274
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2275
        var (
×
2276
                ctx       = context.TODO()
×
2277
                channelID uint64
×
2278
        )
×
2279
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2280
                chanID, err := db.GetSCIDByOutpoint(
×
2281
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2282
                                Outpoint: chanPoint.String(),
×
2283
                                Version:  int16(lnwire.GossipVersion1),
×
2284
                        },
×
2285
                )
×
2286
                if errors.Is(err, sql.ErrNoRows) {
×
2287
                        return ErrEdgeNotFound
×
2288
                } else if err != nil {
×
2289
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2290
                                err)
×
2291
                }
×
2292

2293
                channelID = byteOrder.Uint64(chanID)
×
2294

×
2295
                return nil
×
2296
        }, sqldb.NoOpReset)
2297
        if err != nil {
×
2298
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2299
        }
×
2300

2301
        return channelID, nil
×
2302
}
2303

2304
// IsPublicNode is a helper method that determines whether the node with the
2305
// given public key is seen as a public node in the graph from the graph's
2306
// source node's point of view.
2307
//
2308
// NOTE: part of the V1Store interface.
2309
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2310
        ctx := context.TODO()
×
2311

×
NEW
2312
        // Check the cache first and return early if there is a hit.
×
NEW
2313
        cached, err := s.publicNodeCache.Get(pubKey)
×
NEW
2314
        if err == nil && cached != nil {
×
NEW
2315
                return true, nil
×
NEW
2316
        }
×
2317

2318
        // Log any error other than NotFound.
NEW
2319
        if err != nil && !errors.Is(err, cache.ErrElementNotFound) {
×
NEW
2320
                log.Warnf("Unable to check cache if node is public: %v", err)
×
NEW
2321
        }
×
2322

2323
        var isPublic bool
×
NEW
2324
        err = s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2325
                var err error
×
2326
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2327

×
2328
                return err
×
2329
        }, sqldb.NoOpReset)
×
2330
        if err != nil {
×
2331
                return false, fmt.Errorf("unable to check if node is "+
×
2332
                        "public: %w", err)
×
2333
        }
×
2334

2335
        // Store the result in cache only if the node is public.
NEW
2336
        if isPublic {
×
NEW
2337
                _, err = s.publicNodeCache.Put(pubKey, &cachedPublicNode{})
×
NEW
2338
                if err != nil {
×
NEW
2339
                        log.Warnf("Unable to store node info in cache: %v", err)
×
NEW
2340
                }
×
2341
        }
2342

UNCOV
2343
        return isPublic, nil
×
2344
}
2345

2346
// FetchChanInfos returns the set of channel edges that correspond to the passed
2347
// channel ID's. If an edge is the query is unknown to the database, it will
2348
// skipped and the result will contain only those edges that exist at the time
2349
// of the query. This can be used to respond to peer queries that are seeking to
2350
// fill in gaps in their view of the channel graph.
2351
//
2352
// NOTE: part of the V1Store interface.
2353
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2354
        var (
×
2355
                ctx   = context.TODO()
×
2356
                edges = make(map[uint64]ChannelEdge)
×
2357
        )
×
2358
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2359
                // First, collect all channel rows.
×
2360
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2361
                chanCallBack := func(ctx context.Context,
×
2362
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2363

×
2364
                        channelRows = append(channelRows, row)
×
2365
                        return nil
×
2366
                }
×
2367

2368
                err := s.forEachChanWithPoliciesInSCIDList(
×
2369
                        ctx, db, chanCallBack, chanIDs,
×
2370
                )
×
2371
                if err != nil {
×
2372
                        return err
×
2373
                }
×
2374

2375
                if len(channelRows) == 0 {
×
2376
                        return nil
×
2377
                }
×
2378

2379
                // Batch build all channel edges.
2380
                chans, err := batchBuildChannelEdges(
×
2381
                        ctx, s.cfg, db, channelRows,
×
2382
                )
×
2383
                if err != nil {
×
2384
                        return fmt.Errorf("unable to build channel edges: %w",
×
2385
                                err)
×
2386
                }
×
2387

2388
                for _, c := range chans {
×
2389
                        edges[c.Info.ChannelID] = c
×
2390
                }
×
2391

2392
                return err
×
2393
        }, func() {
×
2394
                clear(edges)
×
2395
        })
×
2396
        if err != nil {
×
2397
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2398
        }
×
2399

2400
        res := make([]ChannelEdge, 0, len(edges))
×
2401
        for _, chanID := range chanIDs {
×
2402
                edge, ok := edges[chanID]
×
2403
                if !ok {
×
2404
                        continue
×
2405
                }
2406

2407
                res = append(res, edge)
×
2408
        }
2409

2410
        return res, nil
×
2411
}
2412

2413
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2414
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2415
// channels in a paginated manner.
2416
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2417
        db SQLQueries, cb func(ctx context.Context,
2418
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2419
        chanIDs []uint64) error {
×
2420

×
2421
        queryWrapper := func(ctx context.Context,
×
2422
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2423
                error) {
×
2424

×
2425
                return db.GetChannelsBySCIDWithPolicies(
×
2426
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2427
                                Version: int16(lnwire.GossipVersion1),
×
2428
                                Scids:   scids,
×
2429
                        },
×
2430
                )
×
2431
        }
×
2432

2433
        return sqldb.ExecuteBatchQuery(
×
2434
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2435
                cb,
×
2436
        )
×
2437
}
2438

2439
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2440
// ID's that we don't know and are not known zombies of the passed set. In other
2441
// words, we perform a set difference of our set of chan ID's and the ones
2442
// passed in. This method can be used by callers to determine the set of
2443
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2444
// known zombies is also returned.
2445
//
2446
// NOTE: part of the V1Store interface.
2447
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2448
        []ChannelUpdateInfo, error) {
×
2449

×
2450
        var (
×
2451
                ctx          = context.TODO()
×
2452
                newChanIDs   []uint64
×
2453
                knownZombies []ChannelUpdateInfo
×
2454
                infoLookup   = make(
×
2455
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2456
                )
×
2457
        )
×
2458

×
2459
        // We first build a lookup map of the channel ID's to the
×
2460
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2461
        // already know about.
×
2462
        for _, chanInfo := range chansInfo {
×
2463
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2464
        }
×
2465

2466
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2467
                // The call-back function deletes known channels from
×
2468
                // infoLookup, so that we can later check which channels are
×
2469
                // zombies by only looking at the remaining channels in the set.
×
2470
                cb := func(ctx context.Context,
×
2471
                        channel sqlc.GraphChannel) error {
×
2472

×
2473
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2474

×
2475
                        return nil
×
2476
                }
×
2477

2478
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2479
                if err != nil {
×
2480
                        return fmt.Errorf("unable to iterate through "+
×
2481
                                "channels: %w", err)
×
2482
                }
×
2483

2484
                // We want to ensure that we deal with the channels in the
2485
                // same order that they were passed in, so we iterate over the
2486
                // original chansInfo slice and then check if that channel is
2487
                // still in the infoLookup map.
2488
                for _, chanInfo := range chansInfo {
×
2489
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2490
                        if _, ok := infoLookup[channelID]; !ok {
×
2491
                                continue
×
2492
                        }
2493

2494
                        isZombie, err := db.IsZombieChannel(
×
2495
                                ctx, sqlc.IsZombieChannelParams{
×
2496
                                        Scid:    channelIDToBytes(channelID),
×
2497
                                        Version: int16(lnwire.GossipVersion1),
×
2498
                                },
×
2499
                        )
×
2500
                        if err != nil {
×
2501
                                return fmt.Errorf("unable to fetch zombie "+
×
2502
                                        "channel: %w", err)
×
2503
                        }
×
2504

2505
                        if isZombie {
×
2506
                                knownZombies = append(knownZombies, chanInfo)
×
2507

×
2508
                                continue
×
2509
                        }
2510

2511
                        newChanIDs = append(newChanIDs, channelID)
×
2512
                }
2513

2514
                return nil
×
2515
        }, func() {
×
2516
                newChanIDs = nil
×
2517
                knownZombies = nil
×
2518
                // Rebuild the infoLookup map in case of a rollback.
×
2519
                for _, chanInfo := range chansInfo {
×
2520
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2521
                        infoLookup[scid] = chanInfo
×
2522
                }
×
2523
        })
2524
        if err != nil {
×
2525
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2526
        }
×
2527

2528
        return newChanIDs, knownZombies, nil
×
2529
}
2530

2531
// forEachChanInSCIDList is a helper method that executes a paged query
2532
// against the database to fetch all channels that match the passed
2533
// ChannelUpdateInfo slice. The callback function is called for each channel
2534
// that is found.
2535
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2536
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2537
        chansInfo []ChannelUpdateInfo) error {
×
2538

×
2539
        queryWrapper := func(ctx context.Context,
×
2540
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2541

×
2542
                return db.GetChannelsBySCIDs(
×
2543
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2544
                                Version: int16(lnwire.GossipVersion1),
×
2545
                                Scids:   scids,
×
2546
                        },
×
2547
                )
×
2548
        }
×
2549

2550
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2551
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2552

×
2553
                return channelIDToBytes(channelID)
×
2554
        }
×
2555

2556
        return sqldb.ExecuteBatchQuery(
×
2557
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2558
                cb,
×
2559
        )
×
2560
}
2561

2562
// PruneGraphNodes is a garbage collection method which attempts to prune out
2563
// any nodes from the channel graph that are currently unconnected. This ensure
2564
// that we only maintain a graph of reachable nodes. In the event that a pruned
2565
// node gains more channels, it will be re-added back to the graph.
2566
//
2567
// NOTE: this prunes nodes across protocol versions. It will never prune the
2568
// source nodes.
2569
//
2570
// NOTE: part of the V1Store interface.
2571
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2572
        var ctx = context.TODO()
×
2573

×
2574
        var prunedNodes []route.Vertex
×
2575
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2576
                var err error
×
2577
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2578

×
2579
                return err
×
2580
        }, func() {
×
2581
                prunedNodes = nil
×
2582
        })
×
2583
        if err != nil {
×
2584
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2585
        }
×
2586

2587
        return prunedNodes, nil
×
2588
}
2589

2590
// PruneGraph prunes newly closed channels from the channel graph in response
2591
// to a new block being solved on the network. Any transactions which spend the
2592
// funding output of any known channels within he graph will be deleted.
2593
// Additionally, the "prune tip", or the last block which has been used to
2594
// prune the graph is stored so callers can ensure the graph is fully in sync
2595
// with the current UTXO state. A slice of channels that have been closed by
2596
// the target block along with any pruned nodes are returned if the function
2597
// succeeds without error.
2598
//
2599
// NOTE: part of the V1Store interface.
2600
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2601
        blockHash *chainhash.Hash, blockHeight uint32) (
2602
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2603

×
2604
        ctx := context.TODO()
×
2605

×
2606
        s.cacheMu.Lock()
×
2607
        defer s.cacheMu.Unlock()
×
2608

×
2609
        var (
×
2610
                closedChans []*models.ChannelEdgeInfo
×
2611
                prunedNodes []route.Vertex
×
2612
        )
×
2613
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2614
                // First, collect all channel rows that need to be pruned.
×
2615
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2616
                channelCallback := func(ctx context.Context,
×
2617
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2618

×
2619
                        channelRows = append(channelRows, row)
×
2620

×
2621
                        return nil
×
2622
                }
×
2623

2624
                err := s.forEachChanInOutpoints(
×
2625
                        ctx, db, spentOutputs, channelCallback,
×
2626
                )
×
2627
                if err != nil {
×
2628
                        return fmt.Errorf("unable to fetch channels by "+
×
2629
                                "outpoints: %w", err)
×
2630
                }
×
2631

2632
                if len(channelRows) == 0 {
×
2633
                        // There are no channels to prune. So we can exit early
×
2634
                        // after updating the prune log.
×
2635
                        err = db.UpsertPruneLogEntry(
×
2636
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2637
                                        BlockHash:   blockHash[:],
×
2638
                                        BlockHeight: int64(blockHeight),
×
2639
                                },
×
2640
                        )
×
2641
                        if err != nil {
×
2642
                                return fmt.Errorf("unable to insert prune log "+
×
2643
                                        "entry: %w", err)
×
2644
                        }
×
2645

2646
                        return nil
×
2647
                }
2648

2649
                // Batch build all channel edges for pruning.
2650
                var chansToDelete []int64
×
2651
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2652
                        ctx, s.cfg, db, channelRows,
×
2653
                )
×
2654
                if err != nil {
×
2655
                        return err
×
2656
                }
×
2657

2658
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2659
                if err != nil {
×
2660
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2661
                }
×
2662

2663
                err = db.UpsertPruneLogEntry(
×
2664
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2665
                                BlockHash:   blockHash[:],
×
2666
                                BlockHeight: int64(blockHeight),
×
2667
                        },
×
2668
                )
×
2669
                if err != nil {
×
2670
                        return fmt.Errorf("unable to insert prune log "+
×
2671
                                "entry: %w", err)
×
2672
                }
×
2673

2674
                // Now that we've pruned some channels, we'll also prune any
2675
                // nodes that no longer have any channels.
2676
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2677
                if err != nil {
×
2678
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2679
                                err)
×
2680
                }
×
2681

2682
                return nil
×
2683
        }, func() {
×
2684
                prunedNodes = nil
×
2685
                closedChans = nil
×
2686
        })
×
2687
        if err != nil {
×
2688
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2689
        }
×
2690

2691
        for _, channel := range closedChans {
×
2692
                s.rejectCache.remove(channel.ChannelID)
×
2693
                s.chanCache.remove(channel.ChannelID)
×
2694
        }
×
2695

2696
        return closedChans, prunedNodes, nil
×
2697
}
2698

2699
// forEachChanInOutpoints is a helper function that executes a paginated
2700
// query to fetch channels by their outpoints and applies the given call-back
2701
// to each.
2702
//
2703
// NOTE: this fetches channels for all protocol versions.
2704
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2705
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2706
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2707

×
2708
        // Create a wrapper that uses the transaction's db instance to execute
×
2709
        // the query.
×
2710
        queryWrapper := func(ctx context.Context,
×
2711
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2712
                error) {
×
2713

×
2714
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2715
        }
×
2716

2717
        // Define the conversion function from Outpoint to string.
2718
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2719
                return outpoint.String()
×
2720
        }
×
2721

2722
        return sqldb.ExecuteBatchQuery(
×
2723
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2724
                queryWrapper, cb,
×
2725
        )
×
2726
}
2727

2728
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2729
        dbIDs []int64) error {
×
2730

×
2731
        // Create a wrapper that uses the transaction's db instance to execute
×
2732
        // the query.
×
2733
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2734
                return nil, db.DeleteChannels(ctx, ids)
×
2735
        }
×
2736

2737
        idConverter := func(id int64) int64 {
×
2738
                return id
×
2739
        }
×
2740

2741
        return sqldb.ExecuteBatchQuery(
×
2742
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2743
                queryWrapper, func(ctx context.Context, _ any) error {
×
2744
                        return nil
×
2745
                },
×
2746
        )
2747
}
2748

2749
// ChannelView returns the verifiable edge information for each active channel
2750
// within the known channel graph. The set of UTXOs (along with their scripts)
2751
// returned are the ones that need to be watched on chain to detect channel
2752
// closes on the resident blockchain.
2753
//
2754
// NOTE: part of the V1Store interface.
2755
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2756
        var (
×
2757
                ctx        = context.TODO()
×
2758
                edgePoints []EdgePoint
×
2759
        )
×
2760

×
2761
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2762
                handleChannel := func(_ context.Context,
×
2763
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2764

×
2765
                        pkScript, err := genMultiSigP2WSH(
×
2766
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2767
                        )
×
2768
                        if err != nil {
×
2769
                                return err
×
2770
                        }
×
2771

2772
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2773
                        if err != nil {
×
2774
                                return err
×
2775
                        }
×
2776

2777
                        edgePoints = append(edgePoints, EdgePoint{
×
2778
                                FundingPkScript: pkScript,
×
2779
                                OutPoint:        *op,
×
2780
                        })
×
2781

×
2782
                        return nil
×
2783
                }
2784

2785
                queryFunc := func(ctx context.Context, lastID int64,
×
2786
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2787

×
2788
                        return db.ListChannelsPaginated(
×
2789
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2790
                                        Version: int16(lnwire.GossipVersion1),
×
2791
                                        ID:      lastID,
×
2792
                                        Limit:   limit,
×
2793
                                },
×
2794
                        )
×
2795
                }
×
2796

2797
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2798
                        return row.ID
×
2799
                }
×
2800

2801
                return sqldb.ExecutePaginatedQuery(
×
2802
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2803
                        extractCursor, handleChannel,
×
2804
                )
×
2805
        }, func() {
×
2806
                edgePoints = nil
×
2807
        })
×
2808
        if err != nil {
×
2809
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2810
        }
×
2811

2812
        return edgePoints, nil
×
2813
}
2814

2815
// PruneTip returns the block height and hash of the latest block that has been
2816
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2817
// to tell if the graph is currently in sync with the current best known UTXO
2818
// state.
2819
//
2820
// NOTE: part of the V1Store interface.
2821
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2822
        var (
×
2823
                ctx       = context.TODO()
×
2824
                tipHash   chainhash.Hash
×
2825
                tipHeight uint32
×
2826
        )
×
2827
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2828
                pruneTip, err := db.GetPruneTip(ctx)
×
2829
                if errors.Is(err, sql.ErrNoRows) {
×
2830
                        return ErrGraphNeverPruned
×
2831
                } else if err != nil {
×
2832
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2833
                }
×
2834

2835
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2836
                tipHeight = uint32(pruneTip.BlockHeight)
×
2837

×
2838
                return nil
×
2839
        }, sqldb.NoOpReset)
2840
        if err != nil {
×
2841
                return nil, 0, err
×
2842
        }
×
2843

2844
        return &tipHash, tipHeight, nil
×
2845
}
2846

2847
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2848
//
2849
// NOTE: this prunes nodes across protocol versions. It will never prune the
2850
// source nodes.
2851
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2852
        db SQLQueries) ([]route.Vertex, error) {
×
2853

×
2854
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2855
        if err != nil {
×
2856
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2857
                        "nodes: %w", err)
×
2858
        }
×
2859

2860
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2861
        for i, nodeKey := range nodeKeys {
×
2862
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2863
                if err != nil {
×
2864
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2865
                                "from bytes: %w", err)
×
2866
                }
×
2867

2868
                prunedNodes[i] = pub
×
2869
        }
2870

2871
        return prunedNodes, nil
×
2872
}
2873

2874
// DisconnectBlockAtHeight is used to indicate that the block specified
2875
// by the passed height has been disconnected from the main chain. This
2876
// will "rewind" the graph back to the height below, deleting channels
2877
// that are no longer confirmed from the graph. The prune log will be
2878
// set to the last prune height valid for the remaining chain.
2879
// Channels that were removed from the graph resulting from the
2880
// disconnected block are returned.
2881
//
2882
// NOTE: part of the V1Store interface.
2883
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2884
        []*models.ChannelEdgeInfo, error) {
×
2885

×
2886
        ctx := context.TODO()
×
2887

×
2888
        var (
×
2889
                // Every channel having a ShortChannelID starting at 'height'
×
2890
                // will no longer be confirmed.
×
2891
                startShortChanID = lnwire.ShortChannelID{
×
2892
                        BlockHeight: height,
×
2893
                }
×
2894

×
2895
                // Delete everything after this height from the db up until the
×
2896
                // SCID alias range.
×
2897
                endShortChanID = aliasmgr.StartingAlias
×
2898

×
2899
                removedChans []*models.ChannelEdgeInfo
×
2900

×
2901
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2902
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2903
        )
×
2904

×
2905
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2906
                rows, err := db.GetChannelsBySCIDRange(
×
2907
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2908
                                StartScid: chanIDStart,
×
2909
                                EndScid:   chanIDEnd,
×
2910
                        },
×
2911
                )
×
2912
                if err != nil {
×
2913
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2914
                }
×
2915

2916
                if len(rows) == 0 {
×
2917
                        // No channels to disconnect, but still clean up prune
×
2918
                        // log.
×
2919
                        return db.DeletePruneLogEntriesInRange(
×
2920
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2921
                                        StartHeight: int64(height),
×
2922
                                        EndHeight: int64(
×
2923
                                                endShortChanID.BlockHeight,
×
2924
                                        ),
×
2925
                                },
×
2926
                        )
×
2927
                }
×
2928

2929
                // Batch build all channel edges for disconnection.
2930
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2931
                        ctx, s.cfg, db, rows,
×
2932
                )
×
2933
                if err != nil {
×
2934
                        return err
×
2935
                }
×
2936

2937
                removedChans = channelEdges
×
2938

×
2939
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2940
                if err != nil {
×
2941
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2942
                }
×
2943

2944
                return db.DeletePruneLogEntriesInRange(
×
2945
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2946
                                StartHeight: int64(height),
×
2947
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2948
                        },
×
2949
                )
×
2950
        }, func() {
×
2951
                removedChans = nil
×
2952
        })
×
2953
        if err != nil {
×
2954
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2955
                        "height: %w", err)
×
2956
        }
×
2957

2958
        s.cacheMu.Lock()
×
2959
        for _, channel := range removedChans {
×
2960
                s.rejectCache.remove(channel.ChannelID)
×
2961
                s.chanCache.remove(channel.ChannelID)
×
2962
        }
×
2963
        s.cacheMu.Unlock()
×
2964

×
2965
        return removedChans, nil
×
2966
}
2967

2968
// AddEdgeProof sets the proof of an existing edge in the graph database.
2969
//
2970
// NOTE: part of the V1Store interface.
2971
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2972
        proof *models.ChannelAuthProof) error {
×
2973

×
2974
        var (
×
2975
                ctx       = context.TODO()
×
2976
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2977
        )
×
2978

×
2979
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2980
                res, err := db.AddV1ChannelProof(
×
2981
                        ctx, sqlc.AddV1ChannelProofParams{
×
2982
                                Scid:              scidBytes,
×
2983
                                Node1Signature:    proof.NodeSig1Bytes,
×
2984
                                Node2Signature:    proof.NodeSig2Bytes,
×
2985
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2986
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2987
                        },
×
2988
                )
×
2989
                if err != nil {
×
2990
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2991
                }
×
2992

2993
                n, err := res.RowsAffected()
×
2994
                if err != nil {
×
2995
                        return err
×
2996
                }
×
2997

2998
                if n == 0 {
×
2999
                        return fmt.Errorf("no rows affected when adding edge "+
×
3000
                                "proof for SCID %v", scid)
×
3001
                } else if n > 1 {
×
3002
                        return fmt.Errorf("multiple rows affected when adding "+
×
3003
                                "edge proof for SCID %v: %d rows affected",
×
3004
                                scid, n)
×
3005
                }
×
3006

3007
                return nil
×
3008
        }, sqldb.NoOpReset)
3009
        if err != nil {
×
3010
                return fmt.Errorf("unable to add edge proof: %w", err)
×
3011
        }
×
3012

3013
        return nil
×
3014
}
3015

3016
// PutClosedScid stores a SCID for a closed channel in the database. This is so
3017
// that we can ignore channel announcements that we know to be closed without
3018
// having to validate them and fetch a block.
3019
//
3020
// NOTE: part of the V1Store interface.
3021
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
3022
        var (
×
3023
                ctx     = context.TODO()
×
3024
                chanIDB = channelIDToBytes(scid.ToUint64())
×
3025
        )
×
3026

×
3027
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3028
                return db.InsertClosedChannel(ctx, chanIDB)
×
3029
        }, sqldb.NoOpReset)
×
3030
}
3031

3032
// IsClosedScid checks whether a channel identified by the passed in scid is
3033
// closed. This helps avoid having to perform expensive validation checks.
3034
//
3035
// NOTE: part of the V1Store interface.
3036
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
3037
        var (
×
3038
                ctx      = context.TODO()
×
3039
                isClosed bool
×
3040
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
3041
        )
×
3042
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3043
                var err error
×
3044
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
3045
                if err != nil {
×
3046
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
3047
                                err)
×
3048
                }
×
3049

3050
                return nil
×
3051
        }, sqldb.NoOpReset)
3052
        if err != nil {
×
3053
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3054
                        err)
×
3055
        }
×
3056

3057
        return isClosed, nil
×
3058
}
3059

3060
// GraphSession will provide the call-back with access to a NodeTraverser
3061
// instance which can be used to perform queries against the channel graph.
3062
//
3063
// NOTE: part of the V1Store interface.
3064
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3065
        reset func()) error {
×
3066

×
3067
        var ctx = context.TODO()
×
3068

×
3069
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3070
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3071
        }, reset)
×
3072
}
3073

3074
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3075
// read only transaction for a consistent view of the graph.
3076
type sqlNodeTraverser struct {
3077
        db    SQLQueries
3078
        chain chainhash.Hash
3079
}
3080

3081
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3082
// NodeTraverser interface.
3083
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3084

3085
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3086
func newSQLNodeTraverser(db SQLQueries,
3087
        chain chainhash.Hash) *sqlNodeTraverser {
×
3088

×
3089
        return &sqlNodeTraverser{
×
3090
                db:    db,
×
3091
                chain: chain,
×
3092
        }
×
3093
}
×
3094

3095
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3096
// node.
3097
//
3098
// NOTE: Part of the NodeTraverser interface.
3099
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3100
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3101

×
3102
        ctx := context.TODO()
×
3103

×
3104
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3105
}
×
3106

3107
// FetchNodeFeatures returns the features of the given node. If the node is
3108
// unknown, assume no additional features are supported.
3109
//
3110
// NOTE: Part of the NodeTraverser interface.
3111
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3112
        *lnwire.FeatureVector, error) {
×
3113

×
3114
        ctx := context.TODO()
×
3115

×
3116
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3117
}
×
3118

3119
// forEachNodeDirectedChannel iterates through all channels of a given
3120
// node, executing the passed callback on the directed edge representing the
3121
// channel and its incoming policy. If the node is not found, no error is
3122
// returned.
3123
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3124
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3125

×
3126
        toNodeCallback := func() route.Vertex {
×
3127
                return nodePub
×
3128
        }
×
3129

3130
        dbID, err := db.GetNodeIDByPubKey(
×
3131
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3132
                        Version: int16(lnwire.GossipVersion1),
×
3133
                        PubKey:  nodePub[:],
×
3134
                },
×
3135
        )
×
3136
        if errors.Is(err, sql.ErrNoRows) {
×
3137
                return nil
×
3138
        } else if err != nil {
×
3139
                return fmt.Errorf("unable to fetch node: %w", err)
×
3140
        }
×
3141

3142
        rows, err := db.ListChannelsByNodeID(
×
3143
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3144
                        Version: int16(lnwire.GossipVersion1),
×
3145
                        NodeID1: dbID,
×
3146
                },
×
3147
        )
×
3148
        if err != nil {
×
3149
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3150
        }
×
3151

3152
        // Exit early if there are no channels for this node so we don't
3153
        // do the unnecessary feature fetching.
3154
        if len(rows) == 0 {
×
3155
                return nil
×
3156
        }
×
3157

3158
        features, err := getNodeFeatures(ctx, db, dbID)
×
3159
        if err != nil {
×
3160
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3161
        }
×
3162

3163
        for _, row := range rows {
×
3164
                node1, node2, err := buildNodeVertices(
×
3165
                        row.Node1Pubkey, row.Node2Pubkey,
×
3166
                )
×
3167
                if err != nil {
×
3168
                        return fmt.Errorf("unable to build node vertices: %w",
×
3169
                                err)
×
3170
                }
×
3171

3172
                edge := buildCacheableChannelInfo(
×
3173
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3174
                        node1, node2,
×
3175
                )
×
3176

×
3177
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3178
                if err != nil {
×
3179
                        return err
×
3180
                }
×
3181

3182
                p1, p2, err := buildCachedChanPolicies(
×
3183
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3184
                )
×
3185
                if err != nil {
×
3186
                        return err
×
3187
                }
×
3188

3189
                // Determine the outgoing and incoming policy for this
3190
                // channel and node combo.
3191
                outPolicy, inPolicy := p1, p2
×
3192
                if p1 != nil && node2 == nodePub {
×
3193
                        outPolicy, inPolicy = p2, p1
×
3194
                } else if p2 != nil && node1 != nodePub {
×
3195
                        outPolicy, inPolicy = p2, p1
×
3196
                }
×
3197

3198
                var cachedInPolicy *models.CachedEdgePolicy
×
3199
                if inPolicy != nil {
×
3200
                        cachedInPolicy = inPolicy
×
3201
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3202
                        cachedInPolicy.ToNodeFeatures = features
×
3203
                }
×
3204

3205
                directedChannel := &DirectedChannel{
×
3206
                        ChannelID:    edge.ChannelID,
×
3207
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3208
                        OtherNode:    edge.NodeKey2Bytes,
×
3209
                        Capacity:     edge.Capacity,
×
3210
                        OutPolicySet: outPolicy != nil,
×
3211
                        InPolicy:     cachedInPolicy,
×
3212
                }
×
3213
                if outPolicy != nil {
×
3214
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3215
                                directedChannel.InboundFee = fee
×
3216
                        })
×
3217
                }
3218

3219
                if nodePub == edge.NodeKey2Bytes {
×
3220
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3221
                }
×
3222

3223
                if err := cb(directedChannel); err != nil {
×
3224
                        return err
×
3225
                }
×
3226
        }
3227

3228
        return nil
×
3229
}
3230

3231
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3232
// and executes the provided callback for each node. It does so via pagination
3233
// along with batch loading of the node feature bits.
3234
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3235
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3236
                features *lnwire.FeatureVector) error) error {
×
3237

×
3238
        handleNode := func(_ context.Context,
×
3239
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3240
                featureBits map[int64][]int) error {
×
3241

×
3242
                fv := lnwire.EmptyFeatureVector()
×
3243
                if features, exists := featureBits[dbNode.ID]; exists {
×
3244
                        for _, bit := range features {
×
3245
                                fv.Set(lnwire.FeatureBit(bit))
×
3246
                        }
×
3247
                }
3248

3249
                var pub route.Vertex
×
3250
                copy(pub[:], dbNode.PubKey)
×
3251

×
3252
                return processNode(dbNode.ID, pub, fv)
×
3253
        }
3254

3255
        queryFunc := func(ctx context.Context, lastID int64,
×
3256
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3257

×
3258
                return db.ListNodeIDsAndPubKeys(
×
3259
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3260
                                Version: int16(lnwire.GossipVersion1),
×
3261
                                ID:      lastID,
×
3262
                                Limit:   limit,
×
3263
                        },
×
3264
                )
×
3265
        }
×
3266

3267
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3268
                return row.ID
×
3269
        }
×
3270

3271
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3272
                return node.ID, nil
×
3273
        }
×
3274

3275
        batchQueryFunc := func(ctx context.Context,
×
3276
                nodeIDs []int64) (map[int64][]int, error) {
×
3277

×
3278
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3279
        }
×
3280

3281
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3282
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3283
                batchQueryFunc, handleNode,
×
3284
        )
×
3285
}
3286

3287
// forEachNodeChannel iterates through all channels of a node, executing
3288
// the passed callback on each. The call-back is provided with the channel's
3289
// edge information, the outgoing policy and the incoming policy for the
3290
// channel and node combo.
3291
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3292
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3293
                *models.ChannelEdgePolicy,
3294
                *models.ChannelEdgePolicy) error) error {
×
3295

×
3296
        // Get all the V1 channels for this node.
×
3297
        rows, err := db.ListChannelsByNodeID(
×
3298
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3299
                        Version: int16(lnwire.GossipVersion1),
×
3300
                        NodeID1: id,
×
3301
                },
×
3302
        )
×
3303
        if err != nil {
×
3304
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3305
        }
×
3306

3307
        // Collect all the channel and policy IDs.
3308
        var (
×
3309
                chanIDs   = make([]int64, 0, len(rows))
×
3310
                policyIDs = make([]int64, 0, 2*len(rows))
×
3311
        )
×
3312
        for _, row := range rows {
×
3313
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3314

×
3315
                if row.Policy1ID.Valid {
×
3316
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3317
                }
×
3318
                if row.Policy2ID.Valid {
×
3319
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3320
                }
×
3321
        }
3322

3323
        batchData, err := batchLoadChannelData(
×
3324
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3325
        )
×
3326
        if err != nil {
×
3327
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3328
        }
×
3329

3330
        // Call the call-back for each channel and its known policies.
3331
        for _, row := range rows {
×
3332
                node1, node2, err := buildNodeVertices(
×
3333
                        row.Node1Pubkey, row.Node2Pubkey,
×
3334
                )
×
3335
                if err != nil {
×
3336
                        return fmt.Errorf("unable to build node vertices: %w",
×
3337
                                err)
×
3338
                }
×
3339

3340
                edge, err := buildEdgeInfoWithBatchData(
×
3341
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3342
                        batchData,
×
3343
                )
×
3344
                if err != nil {
×
3345
                        return fmt.Errorf("unable to build channel info: %w",
×
3346
                                err)
×
3347
                }
×
3348

3349
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3350
                if err != nil {
×
3351
                        return fmt.Errorf("unable to extract channel "+
×
3352
                                "policies: %w", err)
×
3353
                }
×
3354

3355
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3356
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3357
                )
×
3358
                if err != nil {
×
3359
                        return fmt.Errorf("unable to build channel "+
×
3360
                                "policies: %w", err)
×
3361
                }
×
3362

3363
                // Determine the outgoing and incoming policy for this
3364
                // channel and node combo.
3365
                p1ToNode := row.GraphChannel.NodeID2
×
3366
                p2ToNode := row.GraphChannel.NodeID1
×
3367
                outPolicy, inPolicy := p1, p2
×
3368
                if (p1 != nil && p1ToNode == id) ||
×
3369
                        (p2 != nil && p2ToNode != id) {
×
3370

×
3371
                        outPolicy, inPolicy = p2, p1
×
3372
                }
×
3373

3374
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3375
                        return err
×
3376
                }
×
3377
        }
3378

3379
        return nil
×
3380
}
3381

3382
// updateChanEdgePolicy upserts the channel policy info we have stored for
3383
// a channel we already know of.
3384
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3385
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3386
        error) {
×
3387

×
3388
        var (
×
3389
                node1Pub, node2Pub route.Vertex
×
3390
                isNode1            bool
×
3391
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3392
        )
×
3393

×
3394
        // Check that this edge policy refers to a channel that we already
×
3395
        // know of. We do this explicitly so that we can return the appropriate
×
3396
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3397
        // abort the transaction which would abort the entire batch.
×
3398
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3399
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3400
                        Scid:    chanIDB,
×
3401
                        Version: int16(lnwire.GossipVersion1),
×
3402
                },
×
3403
        )
×
3404
        if errors.Is(err, sql.ErrNoRows) {
×
3405
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3406
        } else if err != nil {
×
3407
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3408
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3409
        }
×
3410

3411
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3412
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3413

×
3414
        // Figure out which node this edge is from.
×
3415
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3416
        nodeID := dbChan.NodeID1
×
3417
        if !isNode1 {
×
3418
                nodeID = dbChan.NodeID2
×
3419
        }
×
3420

3421
        var (
×
3422
                inboundBase sql.NullInt64
×
3423
                inboundRate sql.NullInt64
×
3424
        )
×
3425
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3426
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3427
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3428
        })
×
3429

3430
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3431
                Version:     int16(lnwire.GossipVersion1),
×
3432
                ChannelID:   dbChan.ID,
×
3433
                NodeID:      nodeID,
×
3434
                Timelock:    int32(edge.TimeLockDelta),
×
3435
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3436
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3437
                MinHtlcMsat: int64(edge.MinHTLC),
×
3438
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3439
                Disabled: sql.NullBool{
×
3440
                        Valid: true,
×
3441
                        Bool:  edge.IsDisabled(),
×
3442
                },
×
3443
                MaxHtlcMsat: sql.NullInt64{
×
3444
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3445
                        Int64: int64(edge.MaxHTLC),
×
3446
                },
×
3447
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3448
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3449
                InboundBaseFeeMsat:      inboundBase,
×
3450
                InboundFeeRateMilliMsat: inboundRate,
×
3451
                Signature:               edge.SigBytes,
×
3452
        })
×
3453
        if err != nil {
×
3454
                return node1Pub, node2Pub, isNode1,
×
3455
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3456
        }
×
3457

3458
        // Convert the flat extra opaque data into a map of TLV types to
3459
        // values.
3460
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3461
        if err != nil {
×
3462
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3463
                        "marshal extra opaque data: %w", err)
×
3464
        }
×
3465

3466
        // Update the channel policy's extra signed fields.
3467
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3468
        if err != nil {
×
3469
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3470
                        "policy extra TLVs: %w", err)
×
3471
        }
×
3472

3473
        return node1Pub, node2Pub, isNode1, nil
×
3474
}
3475

3476
// getNodeByPubKey attempts to look up a target node by its public key.
3477
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3478
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3479

×
3480
        dbNode, err := db.GetNodeByPubKey(
×
3481
                ctx, sqlc.GetNodeByPubKeyParams{
×
3482
                        Version: int16(lnwire.GossipVersion1),
×
3483
                        PubKey:  pubKey[:],
×
3484
                },
×
3485
        )
×
3486
        if errors.Is(err, sql.ErrNoRows) {
×
3487
                return 0, nil, ErrGraphNodeNotFound
×
3488
        } else if err != nil {
×
3489
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3490
        }
×
3491

3492
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3493
        if err != nil {
×
3494
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3495
        }
×
3496

3497
        return dbNode.ID, node, nil
×
3498
}
3499

3500
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3501
// provided parameters.
3502
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3503
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3504

×
3505
        return &models.CachedEdgeInfo{
×
3506
                ChannelID:     byteOrder.Uint64(scid),
×
3507
                NodeKey1Bytes: node1Pub,
×
3508
                NodeKey2Bytes: node2Pub,
×
3509
                Capacity:      btcutil.Amount(capacity),
×
3510
        }
×
3511
}
×
3512

3513
// buildNode constructs a Node instance from the given database node
3514
// record. The node's features, addresses and extra signed fields are also
3515
// fetched from the database and set on the node.
3516
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3517
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3518

×
3519
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3520
        if err != nil {
×
3521
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3522
                        err)
×
3523
        }
×
3524

3525
        return buildNodeWithBatchData(dbNode, data)
×
3526
}
3527

3528
// buildNodeWithBatchData builds a models.Node instance
3529
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3530
// features/addresses/extra fields, then the corresponding fields are expected
3531
// to be present in the batchNodeData.
3532
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3533
        batchData *batchNodeData) (*models.Node, error) {
×
3534

×
3535
        if dbNode.Version != int16(lnwire.GossipVersion1) {
×
3536
                return nil, fmt.Errorf("unsupported node version: %d",
×
3537
                        dbNode.Version)
×
3538
        }
×
3539

3540
        var pub [33]byte
×
3541
        copy(pub[:], dbNode.PubKey)
×
3542

×
3543
        node := models.NewV1ShellNode(pub)
×
3544

×
3545
        if len(dbNode.Signature) == 0 {
×
3546
                return node, nil
×
3547
        }
×
3548

3549
        node.AuthSigBytes = dbNode.Signature
×
3550

×
3551
        if dbNode.Alias.Valid {
×
3552
                node.Alias = fn.Some(dbNode.Alias.String)
×
3553
        }
×
3554
        if dbNode.LastUpdate.Valid {
×
3555
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3556
        }
×
3557

3558
        var err error
×
3559
        if dbNode.Color.Valid {
×
3560
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3561
                if err != nil {
×
3562
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3563
                                err)
×
3564
                }
×
3565

3566
                node.Color = fn.Some(nodeColor)
×
3567
        }
3568

3569
        // Use preloaded features.
3570
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3571
                fv := lnwire.EmptyFeatureVector()
×
3572
                for _, bit := range features {
×
3573
                        fv.Set(lnwire.FeatureBit(bit))
×
3574
                }
×
3575
                node.Features = fv
×
3576
        }
3577

3578
        // Use preloaded addresses.
3579
        addresses, exists := batchData.addresses[dbNode.ID]
×
3580
        if exists && len(addresses) > 0 {
×
3581
                node.Addresses, err = buildNodeAddresses(addresses)
×
3582
                if err != nil {
×
3583
                        return nil, fmt.Errorf("unable to build addresses "+
×
3584
                                "for node(%d): %w", dbNode.ID, err)
×
3585
                }
×
3586
        }
3587

3588
        // Use preloaded extra fields.
3589
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3590
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3591
                if err != nil {
×
3592
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3593
                                "signed fields: %w", err)
×
3594
                }
×
3595
                if len(recs) != 0 {
×
3596
                        node.ExtraOpaqueData = recs
×
3597
                }
×
3598
        }
3599

3600
        return node, nil
×
3601
}
3602

3603
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3604
// with the preloaded data, and executes the provided callback for each node.
3605
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3606
        db SQLQueries, nodes []sqlc.GraphNode,
3607
        cb func(dbID int64, node *models.Node) error) error {
×
3608

×
3609
        // Extract node IDs for batch loading.
×
3610
        nodeIDs := make([]int64, len(nodes))
×
3611
        for i, node := range nodes {
×
3612
                nodeIDs[i] = node.ID
×
3613
        }
×
3614

3615
        // Batch load all related data for this page.
3616
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3617
        if err != nil {
×
3618
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3619
        }
×
3620

3621
        for _, dbNode := range nodes {
×
3622
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3623
                if err != nil {
×
3624
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3625
                                dbNode.ID, err)
×
3626
                }
×
3627

3628
                if err := cb(dbNode.ID, node); err != nil {
×
3629
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3630
                                dbNode.ID, err)
×
3631
                }
×
3632
        }
3633

3634
        return nil
×
3635
}
3636

3637
// getNodeFeatures fetches the feature bits and constructs the feature vector
3638
// for a node with the given DB ID.
3639
func getNodeFeatures(ctx context.Context, db SQLQueries,
3640
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3641

×
3642
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3643
        if err != nil {
×
3644
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3645
                        nodeID, err)
×
3646
        }
×
3647

3648
        features := lnwire.EmptyFeatureVector()
×
3649
        for _, feature := range rows {
×
3650
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3651
        }
×
3652

3653
        return features, nil
×
3654
}
3655

3656
// upsertNodeAncillaryData updates the node's features, addresses, and extra
3657
// signed fields. This is common logic shared by upsertNode and
3658
// upsertSourceNode.
3659
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
3660
        nodeID int64, node *models.Node) error {
×
3661

×
3662
        // Update the node's features.
×
3663
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3664
        if err != nil {
×
3665
                return fmt.Errorf("inserting node features: %w", err)
×
3666
        }
×
3667

3668
        // Update the node's addresses.
3669
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3670
        if err != nil {
×
3671
                return fmt.Errorf("inserting node addresses: %w", err)
×
3672
        }
×
3673

3674
        // Convert the flat extra opaque data into a map of TLV types to
3675
        // values.
3676
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3677
        if err != nil {
×
3678
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3679
                        err)
×
3680
        }
×
3681

3682
        // Update the node's extra signed fields.
3683
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3684
        if err != nil {
×
3685
                return fmt.Errorf("inserting node extra TLVs: %w", err)
×
3686
        }
×
3687

3688
        return nil
×
3689
}
3690

3691
// populateNodeParams populates the common node parameters from a models.Node.
3692
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
3693
func populateNodeParams(node *models.Node,
3694
        setParams func(lastUpdate sql.NullInt64, alias,
3695
                colorStr sql.NullString, signature []byte)) error {
×
3696

×
3697
        if !node.HaveAnnouncement() {
×
3698
                return nil
×
3699
        }
×
3700

3701
        switch node.Version {
×
3702
        case lnwire.GossipVersion1:
×
3703
                lastUpdate := sqldb.SQLInt64(node.LastUpdate.Unix())
×
3704
                var alias, colorStr sql.NullString
×
3705

×
3706
                node.Color.WhenSome(func(rgba color.RGBA) {
×
3707
                        colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
3708
                })
×
3709
                node.Alias.WhenSome(func(s string) {
×
3710
                        alias = sqldb.SQLStrValid(s)
×
3711
                })
×
3712

3713
                setParams(lastUpdate, alias, colorStr, node.AuthSigBytes)
×
3714

3715
        case lnwire.GossipVersion2:
×
3716
                // No-op for now.
3717

3718
        default:
×
3719
                return fmt.Errorf("unknown gossip version: %d", node.Version)
×
3720
        }
3721

3722
        return nil
×
3723
}
3724

3725
// buildNodeUpsertParams builds the parameters for upserting a node using the
3726
// strict UpsertNode query (requires timestamp to be increasing).
3727
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
×
3728
        params := sqlc.UpsertNodeParams{
×
3729
                Version: int16(lnwire.GossipVersion1),
×
3730
                PubKey:  node.PubKeyBytes[:],
×
3731
        }
×
3732

×
3733
        err := populateNodeParams(
×
3734
                node, func(lastUpdate sql.NullInt64, alias,
×
3735
                        colorStr sql.NullString,
×
3736
                        signature []byte) {
×
3737

×
3738
                        params.LastUpdate = lastUpdate
×
3739
                        params.Alias = alias
×
3740
                        params.Color = colorStr
×
3741
                        params.Signature = signature
×
3742
                })
×
3743

3744
        return params, err
×
3745
}
3746

3747
// buildSourceNodeUpsertParams builds the parameters for upserting the source
3748
// node using the lenient UpsertSourceNode query (allows same timestamp).
3749
func buildSourceNodeUpsertParams(node *models.Node) (
3750
        sqlc.UpsertSourceNodeParams, error) {
×
3751

×
3752
        params := sqlc.UpsertSourceNodeParams{
×
3753
                Version: int16(lnwire.GossipVersion1),
×
3754
                PubKey:  node.PubKeyBytes[:],
×
3755
        }
×
3756

×
3757
        err := populateNodeParams(
×
3758
                node, func(lastUpdate sql.NullInt64, alias,
×
3759
                        colorStr sql.NullString, signature []byte) {
×
3760

×
3761
                        params.LastUpdate = lastUpdate
×
3762
                        params.Alias = alias
×
3763
                        params.Color = colorStr
×
3764
                        params.Signature = signature
×
3765
                },
×
3766
        )
3767

3768
        return params, err
×
3769
}
3770

3771
// upsertSourceNode upserts the source node record into the database using a
3772
// less strict upsert that allows updates even when the timestamp hasn't
3773
// changed. This is necessary to handle concurrent updates to our own node
3774
// during startup and runtime. The node's features, addresses and extra TLV
3775
// types are also updated. The node's DB ID is returned.
3776
func upsertSourceNode(ctx context.Context, db SQLQueries,
3777
        node *models.Node) (int64, error) {
×
3778

×
3779
        params, err := buildSourceNodeUpsertParams(node)
×
3780
        if err != nil {
×
3781
                return 0, err
×
3782
        }
×
3783

3784
        nodeID, err := db.UpsertSourceNode(ctx, params)
×
3785
        if err != nil {
×
3786
                return 0, fmt.Errorf("upserting source node(%x): %w",
×
3787
                        node.PubKeyBytes, err)
×
3788
        }
×
3789

3790
        // We can exit here if we don't have the announcement yet.
3791
        if !node.HaveAnnouncement() {
×
3792
                return nodeID, nil
×
3793
        }
×
3794

3795
        // Update the ancillary node data (features, addresses, extra fields).
3796
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3797
        if err != nil {
×
3798
                return 0, err
×
3799
        }
×
3800

3801
        return nodeID, nil
×
3802
}
3803

3804
// upsertNode upserts the node record into the database. If the node already
3805
// exists, then the node's information is updated. If the node doesn't exist,
3806
// then a new node is created. The node's features, addresses and extra TLV
3807
// types are also updated. The node's DB ID is returned.
3808
func upsertNode(ctx context.Context, db SQLQueries,
3809
        node *models.Node) (int64, error) {
×
3810

×
3811
        params, err := buildNodeUpsertParams(node)
×
3812
        if err != nil {
×
3813
                return 0, err
×
3814
        }
×
3815

3816
        nodeID, err := db.UpsertNode(ctx, params)
×
3817
        if err != nil {
×
3818
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3819
                        err)
×
3820
        }
×
3821

3822
        // We can exit here if we don't have the announcement yet.
3823
        if !node.HaveAnnouncement() {
×
3824
                return nodeID, nil
×
3825
        }
×
3826

3827
        // Update the ancillary node data (features, addresses, extra fields).
3828
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3829
        if err != nil {
×
3830
                return 0, err
×
3831
        }
×
3832

3833
        return nodeID, nil
×
3834
}
3835

3836
// upsertNodeFeatures updates the node's features node_features table. This
3837
// includes deleting any feature bits no longer present and inserting any new
3838
// feature bits. If the feature bit does not yet exist in the features table,
3839
// then an entry is created in that table first.
3840
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3841
        features *lnwire.FeatureVector) error {
×
3842

×
3843
        // Get any existing features for the node.
×
3844
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3845
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3846
                return err
×
3847
        }
×
3848

3849
        // Copy the nodes latest set of feature bits.
3850
        newFeatures := make(map[int32]struct{})
×
3851
        if features != nil {
×
3852
                for feature := range features.Features() {
×
3853
                        newFeatures[int32(feature)] = struct{}{}
×
3854
                }
×
3855
        }
3856

3857
        // For any current feature that already exists in the DB, remove it from
3858
        // the in-memory map. For any existing feature that does not exist in
3859
        // the in-memory map, delete it from the database.
3860
        for _, feature := range existingFeatures {
×
3861
                // The feature is still present, so there are no updates to be
×
3862
                // made.
×
3863
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3864
                        delete(newFeatures, feature.FeatureBit)
×
3865
                        continue
×
3866
                }
3867

3868
                // The feature is no longer present, so we remove it from the
3869
                // database.
3870
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3871
                        NodeID:     nodeID,
×
3872
                        FeatureBit: feature.FeatureBit,
×
3873
                })
×
3874
                if err != nil {
×
3875
                        return fmt.Errorf("unable to delete node(%d) "+
×
3876
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3877
                                err)
×
3878
                }
×
3879
        }
3880

3881
        // Any remaining entries in newFeatures are new features that need to be
3882
        // added to the database for the first time.
3883
        for feature := range newFeatures {
×
3884
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3885
                        NodeID:     nodeID,
×
3886
                        FeatureBit: feature,
×
3887
                })
×
3888
                if err != nil {
×
3889
                        return fmt.Errorf("unable to insert node(%d) "+
×
3890
                                "feature(%v): %w", nodeID, feature, err)
×
3891
                }
×
3892
        }
3893

3894
        return nil
×
3895
}
3896

3897
// fetchNodeFeatures fetches the features for a node with the given public key.
3898
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3899
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3900

×
3901
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3902
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3903
                        PubKey:  nodePub[:],
×
3904
                        Version: int16(lnwire.GossipVersion1),
×
3905
                },
×
3906
        )
×
3907
        if err != nil {
×
3908
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3909
                        nodePub, err)
×
3910
        }
×
3911

3912
        features := lnwire.EmptyFeatureVector()
×
3913
        for _, bit := range rows {
×
3914
                features.Set(lnwire.FeatureBit(bit))
×
3915
        }
×
3916

3917
        return features, nil
×
3918
}
3919

3920
// dbAddressType is an enum of the address kinds stored in the node address
// table. The type of an address determines how it is serialised when written
// and deserialised when read back.
//
// These values are persisted (as the Type column of each address row), so
// existing values must never be changed or reused.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address with an IPv4 IP.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address with an IPv6 IP.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS marks a DNS host name address.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque marks address data of a type this node does not
	// understand, kept as-is. Its value is math.MaxInt8 (127).
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3933

3934
// collectAddressRecords collects the addresses from the provided
3935
// net.Addr slice and returns a map of dbAddressType to a slice of address
3936
// strings.
3937
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3938
        error) {
×
3939

×
3940
        // Copy the nodes latest set of addresses.
×
3941
        newAddresses := map[dbAddressType][]string{
×
3942
                addressTypeIPv4:   {},
×
3943
                addressTypeIPv6:   {},
×
3944
                addressTypeTorV2:  {},
×
3945
                addressTypeTorV3:  {},
×
3946
                addressTypeDNS:    {},
×
3947
                addressTypeOpaque: {},
×
3948
        }
×
3949
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3950
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3951
        }
×
3952

3953
        for _, address := range addresses {
×
3954
                switch addr := address.(type) {
×
3955
                case *net.TCPAddr:
×
3956
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3957
                                addAddr(addressTypeIPv4, addr)
×
3958
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3959
                                addAddr(addressTypeIPv6, addr)
×
3960
                        } else {
×
3961
                                return nil, fmt.Errorf("unhandled IP "+
×
3962
                                        "address: %v", addr)
×
3963
                        }
×
3964

3965
                case *tor.OnionAddr:
×
3966
                        switch len(addr.OnionService) {
×
3967
                        case tor.V2Len:
×
3968
                                addAddr(addressTypeTorV2, addr)
×
3969
                        case tor.V3Len:
×
3970
                                addAddr(addressTypeTorV3, addr)
×
3971
                        default:
×
3972
                                return nil, fmt.Errorf("invalid length for " +
×
3973
                                        "a tor address")
×
3974
                        }
3975

3976
                case *lnwire.DNSAddress:
×
3977
                        addAddr(addressTypeDNS, addr)
×
3978

3979
                case *lnwire.OpaqueAddrs:
×
3980
                        addAddr(addressTypeOpaque, addr)
×
3981

3982
                default:
×
3983
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3984
                                addr)
×
3985
                }
3986
        }
3987

3988
        return newAddresses, nil
×
3989
}
3990

3991
// upsertNodeAddresses updates the node's addresses in the database. This
3992
// includes deleting any existing addresses and inserting the new set of
3993
// addresses. The deletion is necessary since the ordering of the addresses may
3994
// change, and we need to ensure that the database reflects the latest set of
3995
// addresses so that at the time of reconstructing the node announcement, the
3996
// order is preserved and the signature over the message remains valid.
3997
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3998
        addresses []net.Addr) error {
×
3999

×
4000
        // Delete any existing addresses for the node. This is required since
×
4001
        // even if the new set of addresses is the same, the ordering may have
×
4002
        // changed for a given address type.
×
4003
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
4004
        if err != nil {
×
4005
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
4006
                        nodeID, err)
×
4007
        }
×
4008

4009
        newAddresses, err := collectAddressRecords(addresses)
×
4010
        if err != nil {
×
4011
                return err
×
4012
        }
×
4013

4014
        // Any remaining entries in newAddresses are new addresses that need to
4015
        // be added to the database for the first time.
4016
        for addrType, addrList := range newAddresses {
×
4017
                for position, addr := range addrList {
×
4018
                        err := db.UpsertNodeAddress(
×
4019
                                ctx, sqlc.UpsertNodeAddressParams{
×
4020
                                        NodeID:   nodeID,
×
4021
                                        Type:     int16(addrType),
×
4022
                                        Address:  addr,
×
4023
                                        Position: int32(position),
×
4024
                                },
×
4025
                        )
×
4026
                        if err != nil {
×
4027
                                return fmt.Errorf("unable to insert "+
×
4028
                                        "node(%d) address(%v): %w", nodeID,
×
4029
                                        addr, err)
×
4030
                        }
×
4031
                }
4032
        }
4033

4034
        return nil
×
4035
}
4036

4037
// getNodeAddresses fetches the addresses for a node with the given DB ID.
4038
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
4039
        error) {
×
4040

×
4041
        // GetNodeAddresses ensures that the addresses for a given type are
×
4042
        // returned in the same order as they were inserted.
×
4043
        rows, err := db.GetNodeAddresses(ctx, id)
×
4044
        if err != nil {
×
4045
                return nil, err
×
4046
        }
×
4047

4048
        addresses := make([]net.Addr, 0, len(rows))
×
4049
        for _, row := range rows {
×
4050
                address := row.Address
×
4051

×
4052
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
4053
                if err != nil {
×
4054
                        return nil, fmt.Errorf("unable to parse address "+
×
4055
                                "for node(%d): %v: %w", id, address, err)
×
4056
                }
×
4057

4058
                addresses = append(addresses, addr)
×
4059
        }
4060

4061
        // If we have no addresses, then we'll return nil instead of an
4062
        // empty slice.
4063
        if len(addresses) == 0 {
×
4064
                addresses = nil
×
4065
        }
×
4066

4067
        return addresses, nil
×
4068
}
4069

4070
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
4071
// database. This includes updating any existing types, inserting any new types,
4072
// and deleting any types that are no longer present.
4073
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
4074
        nodeID int64, extraFields map[uint64][]byte) error {
×
4075

×
4076
        // Get any existing extra signed fields for the node.
×
4077
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
4078
        if err != nil {
×
4079
                return err
×
4080
        }
×
4081

4082
        // Make a lookup map of the existing field types so that we can use it
4083
        // to keep track of any fields we should delete.
4084
        m := make(map[uint64]bool)
×
4085
        for _, field := range existingFields {
×
4086
                m[uint64(field.Type)] = true
×
4087
        }
×
4088

4089
        // For all the new fields, we'll upsert them and remove them from the
4090
        // map of existing fields.
4091
        for tlvType, value := range extraFields {
×
4092
                err = db.UpsertNodeExtraType(
×
4093
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4094
                                NodeID: nodeID,
×
4095
                                Type:   int64(tlvType),
×
4096
                                Value:  value,
×
4097
                        },
×
4098
                )
×
4099
                if err != nil {
×
4100
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4101
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4102
                }
×
4103

4104
                // Remove the field from the map of existing fields if it was
4105
                // present.
4106
                delete(m, tlvType)
×
4107
        }
4108

4109
        // For all the fields that are left in the map of existing fields, we'll
4110
        // delete them as they are no longer present in the new set of fields.
4111
        for tlvType := range m {
×
4112
                err = db.DeleteExtraNodeType(
×
4113
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4114
                                NodeID: nodeID,
×
4115
                                Type:   int64(tlvType),
×
4116
                        },
×
4117
                )
×
4118
                if err != nil {
×
4119
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4120
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4121
                }
×
4122
        }
4123

4124
        return nil
×
4125
}
4126

4127
// srcNodeInfo holds the information about the source node of the graph.
4128
type srcNodeInfo struct {
4129
        // id is the DB level ID of the source node entry in the "nodes" table.
4130
        id int64
4131

4132
        // pub is the public key of the source node.
4133
        pub route.Vertex
4134
}
4135

4136
// sourceNode returns the DB node ID and pub key of the source node for the
4137
// specified protocol version.
4138
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4139
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4140

×
4141
        s.srcNodeMu.Lock()
×
4142
        defer s.srcNodeMu.Unlock()
×
4143

×
4144
        // If we already have the source node ID and pub key cached, then
×
4145
        // return them.
×
4146
        if info, ok := s.srcNodes[version]; ok {
×
4147
                return info.id, info.pub, nil
×
4148
        }
×
4149

4150
        var pubKey route.Vertex
×
4151

×
4152
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4153
        if err != nil {
×
4154
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4155
                        err)
×
4156
        }
×
4157

4158
        if len(nodes) == 0 {
×
4159
                return 0, pubKey, ErrSourceNodeNotSet
×
4160
        } else if len(nodes) > 1 {
×
4161
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4162
                        "protocol %s found", version)
×
4163
        }
×
4164

4165
        copy(pubKey[:], nodes[0].PubKey)
×
4166

×
4167
        s.srcNodes[version] = &srcNodeInfo{
×
4168
                id:  nodes[0].NodeID,
×
4169
                pub: pubKey,
×
4170
        }
×
4171

×
4172
        return nodes[0].NodeID, pubKey, nil
×
4173
}
4174

4175
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
4176
// This then produces a map from TLV type to value. If the input is not a
4177
// valid TLV stream, then an error is returned.
4178
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4179
        r := bytes.NewReader(data)
×
4180

×
4181
        tlvStream, err := tlv.NewStream()
×
4182
        if err != nil {
×
4183
                return nil, err
×
4184
        }
×
4185

4186
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4187
        // pass it into the P2P decoding variant.
4188
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4189
        if err != nil {
×
4190
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4191
        }
×
4192
        if len(parsedTypes) == 0 {
×
4193
                return nil, nil
×
4194
        }
×
4195

4196
        records := make(map[uint64][]byte)
×
4197
        for k, v := range parsedTypes {
×
4198
                records[uint64(k)] = v
×
4199
        }
×
4200

4201
        return records, nil
×
4202
}
4203

4204
// insertChannel inserts a new channel record into the database.
4205
func insertChannel(ctx context.Context, db SQLQueries,
4206
        edge *models.ChannelEdgeInfo) error {
×
4207

×
4208
        // Make sure that at least a "shell" entry for each node is present in
×
4209
        // the nodes table.
×
4210
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4211
        if err != nil {
×
4212
                return fmt.Errorf("unable to create shell node: %w", err)
×
4213
        }
×
4214

4215
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4216
        if err != nil {
×
4217
                return fmt.Errorf("unable to create shell node: %w", err)
×
4218
        }
×
4219

4220
        var capacity sql.NullInt64
×
4221
        if edge.Capacity != 0 {
×
4222
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4223
        }
×
4224

4225
        createParams := sqlc.CreateChannelParams{
×
4226
                Version:     int16(lnwire.GossipVersion1),
×
4227
                Scid:        channelIDToBytes(edge.ChannelID),
×
4228
                NodeID1:     node1DBID,
×
4229
                NodeID2:     node2DBID,
×
4230
                Outpoint:    edge.ChannelPoint.String(),
×
4231
                Capacity:    capacity,
×
4232
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4233
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4234
        }
×
4235

×
4236
        if edge.AuthProof != nil {
×
4237
                proof := edge.AuthProof
×
4238

×
4239
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4240
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4241
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4242
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4243
        }
×
4244

4245
        // Insert the new channel record.
4246
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4247
        if err != nil {
×
4248
                return err
×
4249
        }
×
4250

4251
        // Insert any channel features.
4252
        for feature := range edge.Features.Features() {
×
4253
                err = db.InsertChannelFeature(
×
4254
                        ctx, sqlc.InsertChannelFeatureParams{
×
4255
                                ChannelID:  dbChanID,
×
4256
                                FeatureBit: int32(feature),
×
4257
                        },
×
4258
                )
×
4259
                if err != nil {
×
4260
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4261
                                "feature(%v): %w", dbChanID, feature, err)
×
4262
                }
×
4263
        }
4264

4265
        // Finally, insert any extra TLV fields in the channel announcement.
4266
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4267
        if err != nil {
×
4268
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4269
                        err)
×
4270
        }
×
4271

4272
        for tlvType, value := range extra {
×
4273
                err := db.UpsertChannelExtraType(
×
4274
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4275
                                ChannelID: dbChanID,
×
4276
                                Type:      int64(tlvType),
×
4277
                                Value:     value,
×
4278
                        },
×
4279
                )
×
4280
                if err != nil {
×
4281
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4282
                                "extra signed field(%v): %w", edge.ChannelID,
×
4283
                                tlvType, err)
×
4284
                }
×
4285
        }
4286

4287
        return nil
×
4288
}
4289

4290
// maybeCreateShellNode checks if a shell node entry exists for the
4291
// given public key. If it does not exist, then a new shell node entry is
4292
// created. The ID of the node is returned. A shell node only has a protocol
4293
// version and public key persisted.
4294
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4295
        pubKey route.Vertex) (int64, error) {
×
4296

×
4297
        dbNode, err := db.GetNodeByPubKey(
×
4298
                ctx, sqlc.GetNodeByPubKeyParams{
×
4299
                        PubKey:  pubKey[:],
×
4300
                        Version: int16(lnwire.GossipVersion1),
×
4301
                },
×
4302
        )
×
4303
        // The node exists. Return the ID.
×
4304
        if err == nil {
×
4305
                return dbNode.ID, nil
×
4306
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4307
                return 0, err
×
4308
        }
×
4309

4310
        // Otherwise, the node does not exist, so we create a shell entry for
4311
        // it.
4312
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4313
                Version: int16(lnwire.GossipVersion1),
×
4314
                PubKey:  pubKey[:],
×
4315
        })
×
4316
        if err != nil {
×
4317
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4318
        }
×
4319

4320
        return id, nil
×
4321
}
4322

4323
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4324
// the database. This includes deleting any existing types and then inserting
4325
// the new types.
4326
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4327
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4328

×
4329
        // Delete all existing extra signed fields for the channel policy.
×
4330
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4331
        if err != nil {
×
4332
                return fmt.Errorf("unable to delete "+
×
4333
                        "existing policy extra signed fields for policy %d: %w",
×
4334
                        chanPolicyID, err)
×
4335
        }
×
4336

4337
        // Insert all new extra signed fields for the channel policy.
4338
        for tlvType, value := range extraFields {
×
4339
                err = db.UpsertChanPolicyExtraType(
×
4340
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4341
                                ChannelPolicyID: chanPolicyID,
×
4342
                                Type:            int64(tlvType),
×
4343
                                Value:           value,
×
4344
                        },
×
4345
                )
×
4346
                if err != nil {
×
4347
                        return fmt.Errorf("unable to insert "+
×
4348
                                "channel_policy(%d) extra signed field(%v): %w",
×
4349
                                chanPolicyID, tlvType, err)
×
4350
                }
×
4351
        }
4352

4353
        return nil
×
4354
}
4355

4356
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4357
// provided dbChanRow and also fetches any other required information
4358
// to construct the edge info.
4359
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4360
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4361
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4362

×
4363
        data, err := batchLoadChannelData(
×
4364
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4365
        )
×
4366
        if err != nil {
×
4367
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4368
                        err)
×
4369
        }
×
4370

4371
        return buildEdgeInfoWithBatchData(
×
4372
                cfg.ChainHash, dbChan, node1, node2, data,
×
4373
        )
×
4374
}
4375

4376
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4377
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4378
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4379
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4380

×
4381
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4382
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4383
                        dbChan.Version)
×
4384
        }
×
4385

4386
        // Use pre-loaded features and extras types.
4387
        fv := lnwire.EmptyFeatureVector()
×
4388
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4389
                for _, bit := range features {
×
4390
                        fv.Set(lnwire.FeatureBit(bit))
×
4391
                }
×
4392
        }
4393

4394
        var extras map[uint64][]byte
×
4395
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4396
        if exists {
×
4397
                extras = channelExtras
×
4398
        } else {
×
4399
                extras = make(map[uint64][]byte)
×
4400
        }
×
4401

4402
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4403
        if err != nil {
×
4404
                return nil, err
×
4405
        }
×
4406

4407
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4408
        if err != nil {
×
4409
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4410
                        "fields: %w", err)
×
4411
        }
×
4412
        if recs == nil {
×
4413
                recs = make([]byte, 0)
×
4414
        }
×
4415

4416
        var btcKey1, btcKey2 route.Vertex
×
4417
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4418
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4419

×
4420
        channel := &models.ChannelEdgeInfo{
×
4421
                ChainHash:        chain,
×
4422
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4423
                NodeKey1Bytes:    node1,
×
4424
                NodeKey2Bytes:    node2,
×
4425
                BitcoinKey1Bytes: btcKey1,
×
4426
                BitcoinKey2Bytes: btcKey2,
×
4427
                ChannelPoint:     *op,
×
4428
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4429
                Features:         fv,
×
4430
                ExtraOpaqueData:  recs,
×
4431
        }
×
4432

×
4433
        // We always set all the signatures at the same time, so we can
×
4434
        // safely check if one signature is present to determine if we have the
×
4435
        // rest of the signatures for the auth proof.
×
4436
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4437
                channel.AuthProof = &models.ChannelAuthProof{
×
4438
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4439
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4440
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4441
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4442
                }
×
4443
        }
×
4444

4445
        return channel, nil
×
4446
}
4447

4448
// buildNodeVertices is a helper that converts raw node public keys
4449
// into route.Vertex instances.
4450
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4451
        route.Vertex, error) {
×
4452

×
4453
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4454
        if err != nil {
×
4455
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4456
                        "create vertex from node1 pubkey: %w", err)
×
4457
        }
×
4458

4459
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4460
        if err != nil {
×
4461
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4462
                        "create vertex from node2 pubkey: %w", err)
×
4463
        }
×
4464

4465
        return node1Vertex, node2Vertex, nil
×
4466
}
4467

4468
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4469
// retrieves all the extra info required to build the complete
4470
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4471
// the provided sqlc.GraphChannelPolicy records are nil.
4472
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4473
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4474
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4475
        *models.ChannelEdgePolicy, error) {
×
4476

×
4477
        if dbPol1 == nil && dbPol2 == nil {
×
4478
                return nil, nil, nil
×
4479
        }
×
4480

4481
        var policyIDs = make([]int64, 0, 2)
×
4482
        if dbPol1 != nil {
×
4483
                policyIDs = append(policyIDs, dbPol1.ID)
×
4484
        }
×
4485
        if dbPol2 != nil {
×
4486
                policyIDs = append(policyIDs, dbPol2.ID)
×
4487
        }
×
4488

4489
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4490
        if err != nil {
×
4491
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4492
                        "data: %w", err)
×
4493
        }
×
4494

4495
        pol1, err := buildChanPolicyWithBatchData(
×
4496
                dbPol1, channelID, node2, batchData,
×
4497
        )
×
4498
        if err != nil {
×
4499
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4500
        }
×
4501

4502
        pol2, err := buildChanPolicyWithBatchData(
×
4503
                dbPol2, channelID, node1, batchData,
×
4504
        )
×
4505
        if err != nil {
×
4506
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4507
        }
×
4508

4509
        return pol1, pol2, nil
×
4510
}
4511

4512
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4513
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4514
// then nil is returned for it.
4515
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4516
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4517
        *models.CachedEdgePolicy, error) {
×
4518

×
4519
        var p1, p2 *models.CachedEdgePolicy
×
4520
        if dbPol1 != nil {
×
4521
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4522
                if err != nil {
×
4523
                        return nil, nil, err
×
4524
                }
×
4525

4526
                p1 = models.NewCachedPolicy(policy1)
×
4527
        }
4528
        if dbPol2 != nil {
×
4529
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4530
                if err != nil {
×
4531
                        return nil, nil, err
×
4532
                }
×
4533

4534
                p2 = models.NewCachedPolicy(policy2)
×
4535
        }
4536

4537
        return p1, p2, nil
×
4538
}
4539

4540
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4541
// provided sqlc.GraphChannelPolicy and other required information.
4542
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4543
        extras map[uint64][]byte,
4544
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4545

×
4546
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4547
        if err != nil {
×
4548
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4549
                        "fields: %w", err)
×
4550
        }
×
4551

4552
        var inboundFee fn.Option[lnwire.Fee]
×
4553
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4554
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4555

×
4556
                inboundFee = fn.Some(lnwire.Fee{
×
4557
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4558
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4559
                })
×
4560
        }
×
4561

4562
        return &models.ChannelEdgePolicy{
×
4563
                SigBytes:  dbPolicy.Signature,
×
4564
                ChannelID: channelID,
×
4565
                LastUpdate: time.Unix(
×
4566
                        dbPolicy.LastUpdate.Int64, 0,
×
4567
                ),
×
4568
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4569
                        dbPolicy.MessageFlags,
×
4570
                ),
×
4571
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4572
                        dbPolicy.ChannelFlags,
×
4573
                ),
×
4574
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4575
                MinHTLC: lnwire.MilliSatoshi(
×
4576
                        dbPolicy.MinHtlcMsat,
×
4577
                ),
×
4578
                MaxHTLC: lnwire.MilliSatoshi(
×
4579
                        dbPolicy.MaxHtlcMsat.Int64,
×
4580
                ),
×
4581
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4582
                        dbPolicy.BaseFeeMsat,
×
4583
                ),
×
4584
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4585
                ToNode:                    toNode,
×
4586
                InboundFee:                inboundFee,
×
4587
                ExtraOpaqueData:           recs,
×
4588
        }, nil
×
4589
}
4590

4591
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4592
// row which is expected to be a sqlc type that contains channel policy
4593
// information. It returns two policies, which may be nil if the policy
4594
// information is not present in the row.
4595
//
4596
//nolint:ll,dupl,funlen
4597
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4598
        *sqlc.GraphChannelPolicy, error) {
×
4599

×
4600
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4601
        switch r := row.(type) {
×
4602
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4603
                if r.Policy1Timelock.Valid {
×
4604
                        policy1 = &sqlc.GraphChannelPolicy{
×
4605
                                Timelock:                r.Policy1Timelock.Int32,
×
4606
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4607
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4608
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4609
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4610
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4611
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4612
                                Disabled:                r.Policy1Disabled,
×
4613
                                MessageFlags:            r.Policy1MessageFlags,
×
4614
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4615
                        }
×
4616
                }
×
4617
                if r.Policy2Timelock.Valid {
×
4618
                        policy2 = &sqlc.GraphChannelPolicy{
×
4619
                                Timelock:                r.Policy2Timelock.Int32,
×
4620
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4621
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4622
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4623
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4624
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4625
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4626
                                Disabled:                r.Policy2Disabled,
×
4627
                                MessageFlags:            r.Policy2MessageFlags,
×
4628
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4629
                        }
×
4630
                }
×
4631

4632
                return policy1, policy2, nil
×
4633

4634
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4635
                if r.Policy1ID.Valid {
×
4636
                        policy1 = &sqlc.GraphChannelPolicy{
×
4637
                                ID:                      r.Policy1ID.Int64,
×
4638
                                Version:                 r.Policy1Version.Int16,
×
4639
                                ChannelID:               r.GraphChannel.ID,
×
4640
                                NodeID:                  r.Policy1NodeID.Int64,
×
4641
                                Timelock:                r.Policy1Timelock.Int32,
×
4642
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4643
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4644
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4645
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4646
                                LastUpdate:              r.Policy1LastUpdate,
×
4647
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4648
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4649
                                Disabled:                r.Policy1Disabled,
×
4650
                                MessageFlags:            r.Policy1MessageFlags,
×
4651
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4652
                                Signature:               r.Policy1Signature,
×
4653
                        }
×
4654
                }
×
4655
                if r.Policy2ID.Valid {
×
4656
                        policy2 = &sqlc.GraphChannelPolicy{
×
4657
                                ID:                      r.Policy2ID.Int64,
×
4658
                                Version:                 r.Policy2Version.Int16,
×
4659
                                ChannelID:               r.GraphChannel.ID,
×
4660
                                NodeID:                  r.Policy2NodeID.Int64,
×
4661
                                Timelock:                r.Policy2Timelock.Int32,
×
4662
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4663
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4664
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4665
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4666
                                LastUpdate:              r.Policy2LastUpdate,
×
4667
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4668
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4669
                                Disabled:                r.Policy2Disabled,
×
4670
                                MessageFlags:            r.Policy2MessageFlags,
×
4671
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4672
                                Signature:               r.Policy2Signature,
×
4673
                        }
×
4674
                }
×
4675

4676
                return policy1, policy2, nil
×
4677

4678
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4679
                if r.Policy1ID.Valid {
×
4680
                        policy1 = &sqlc.GraphChannelPolicy{
×
4681
                                ID:                      r.Policy1ID.Int64,
×
4682
                                Version:                 r.Policy1Version.Int16,
×
4683
                                ChannelID:               r.GraphChannel.ID,
×
4684
                                NodeID:                  r.Policy1NodeID.Int64,
×
4685
                                Timelock:                r.Policy1Timelock.Int32,
×
4686
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4687
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4688
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4689
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4690
                                LastUpdate:              r.Policy1LastUpdate,
×
4691
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4692
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4693
                                Disabled:                r.Policy1Disabled,
×
4694
                                MessageFlags:            r.Policy1MessageFlags,
×
4695
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4696
                                Signature:               r.Policy1Signature,
×
4697
                        }
×
4698
                }
×
4699
                if r.Policy2ID.Valid {
×
4700
                        policy2 = &sqlc.GraphChannelPolicy{
×
4701
                                ID:                      r.Policy2ID.Int64,
×
4702
                                Version:                 r.Policy2Version.Int16,
×
4703
                                ChannelID:               r.GraphChannel.ID,
×
4704
                                NodeID:                  r.Policy2NodeID.Int64,
×
4705
                                Timelock:                r.Policy2Timelock.Int32,
×
4706
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4707
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4708
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4709
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4710
                                LastUpdate:              r.Policy2LastUpdate,
×
4711
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4712
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4713
                                Disabled:                r.Policy2Disabled,
×
4714
                                MessageFlags:            r.Policy2MessageFlags,
×
4715
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4716
                                Signature:               r.Policy2Signature,
×
4717
                        }
×
4718
                }
×
4719

4720
                return policy1, policy2, nil
×
4721

4722
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4723
                if r.Policy1ID.Valid {
×
4724
                        policy1 = &sqlc.GraphChannelPolicy{
×
4725
                                ID:                      r.Policy1ID.Int64,
×
4726
                                Version:                 r.Policy1Version.Int16,
×
4727
                                ChannelID:               r.GraphChannel.ID,
×
4728
                                NodeID:                  r.Policy1NodeID.Int64,
×
4729
                                Timelock:                r.Policy1Timelock.Int32,
×
4730
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4731
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4732
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4733
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4734
                                LastUpdate:              r.Policy1LastUpdate,
×
4735
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4736
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4737
                                Disabled:                r.Policy1Disabled,
×
4738
                                MessageFlags:            r.Policy1MessageFlags,
×
4739
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4740
                                Signature:               r.Policy1Signature,
×
4741
                        }
×
4742
                }
×
4743
                if r.Policy2ID.Valid {
×
4744
                        policy2 = &sqlc.GraphChannelPolicy{
×
4745
                                ID:                      r.Policy2ID.Int64,
×
4746
                                Version:                 r.Policy2Version.Int16,
×
4747
                                ChannelID:               r.GraphChannel.ID,
×
4748
                                NodeID:                  r.Policy2NodeID.Int64,
×
4749
                                Timelock:                r.Policy2Timelock.Int32,
×
4750
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4751
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4752
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4753
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4754
                                LastUpdate:              r.Policy2LastUpdate,
×
4755
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4756
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4757
                                Disabled:                r.Policy2Disabled,
×
4758
                                MessageFlags:            r.Policy2MessageFlags,
×
4759
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4760
                                Signature:               r.Policy2Signature,
×
4761
                        }
×
4762
                }
×
4763

4764
                return policy1, policy2, nil
×
4765

4766
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4767
                if r.Policy1ID.Valid {
×
4768
                        policy1 = &sqlc.GraphChannelPolicy{
×
4769
                                ID:                      r.Policy1ID.Int64,
×
4770
                                Version:                 r.Policy1Version.Int16,
×
4771
                                ChannelID:               r.GraphChannel.ID,
×
4772
                                NodeID:                  r.Policy1NodeID.Int64,
×
4773
                                Timelock:                r.Policy1Timelock.Int32,
×
4774
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4775
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4776
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4777
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4778
                                LastUpdate:              r.Policy1LastUpdate,
×
4779
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4780
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4781
                                Disabled:                r.Policy1Disabled,
×
4782
                                MessageFlags:            r.Policy1MessageFlags,
×
4783
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4784
                                Signature:               r.Policy1Signature,
×
4785
                        }
×
4786
                }
×
4787
                if r.Policy2ID.Valid {
×
4788
                        policy2 = &sqlc.GraphChannelPolicy{
×
4789
                                ID:                      r.Policy2ID.Int64,
×
4790
                                Version:                 r.Policy2Version.Int16,
×
4791
                                ChannelID:               r.GraphChannel.ID,
×
4792
                                NodeID:                  r.Policy2NodeID.Int64,
×
4793
                                Timelock:                r.Policy2Timelock.Int32,
×
4794
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4795
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4796
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4797
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4798
                                LastUpdate:              r.Policy2LastUpdate,
×
4799
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4800
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4801
                                Disabled:                r.Policy2Disabled,
×
4802
                                MessageFlags:            r.Policy2MessageFlags,
×
4803
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4804
                                Signature:               r.Policy2Signature,
×
4805
                        }
×
4806
                }
×
4807

4808
                return policy1, policy2, nil
×
4809

4810
        case sqlc.ListChannelsForNodeIDsRow:
×
4811
                if r.Policy1ID.Valid {
×
4812
                        policy1 = &sqlc.GraphChannelPolicy{
×
4813
                                ID:                      r.Policy1ID.Int64,
×
4814
                                Version:                 r.Policy1Version.Int16,
×
4815
                                ChannelID:               r.GraphChannel.ID,
×
4816
                                NodeID:                  r.Policy1NodeID.Int64,
×
4817
                                Timelock:                r.Policy1Timelock.Int32,
×
4818
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4819
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4820
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4821
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4822
                                LastUpdate:              r.Policy1LastUpdate,
×
4823
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4824
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4825
                                Disabled:                r.Policy1Disabled,
×
4826
                                MessageFlags:            r.Policy1MessageFlags,
×
4827
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4828
                                Signature:               r.Policy1Signature,
×
4829
                        }
×
4830
                }
×
4831
                if r.Policy2ID.Valid {
×
4832
                        policy2 = &sqlc.GraphChannelPolicy{
×
4833
                                ID:                      r.Policy2ID.Int64,
×
4834
                                Version:                 r.Policy2Version.Int16,
×
4835
                                ChannelID:               r.GraphChannel.ID,
×
4836
                                NodeID:                  r.Policy2NodeID.Int64,
×
4837
                                Timelock:                r.Policy2Timelock.Int32,
×
4838
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4839
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4840
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4841
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4842
                                LastUpdate:              r.Policy2LastUpdate,
×
4843
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4844
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4845
                                Disabled:                r.Policy2Disabled,
×
4846
                                MessageFlags:            r.Policy2MessageFlags,
×
4847
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4848
                                Signature:               r.Policy2Signature,
×
4849
                        }
×
4850
                }
×
4851

4852
                return policy1, policy2, nil
×
4853

4854
        case sqlc.ListChannelsByNodeIDRow:
×
4855
                if r.Policy1ID.Valid {
×
4856
                        policy1 = &sqlc.GraphChannelPolicy{
×
4857
                                ID:                      r.Policy1ID.Int64,
×
4858
                                Version:                 r.Policy1Version.Int16,
×
4859
                                ChannelID:               r.GraphChannel.ID,
×
4860
                                NodeID:                  r.Policy1NodeID.Int64,
×
4861
                                Timelock:                r.Policy1Timelock.Int32,
×
4862
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4863
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4864
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4865
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4866
                                LastUpdate:              r.Policy1LastUpdate,
×
4867
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4868
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4869
                                Disabled:                r.Policy1Disabled,
×
4870
                                MessageFlags:            r.Policy1MessageFlags,
×
4871
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4872
                                Signature:               r.Policy1Signature,
×
4873
                        }
×
4874
                }
×
4875
                if r.Policy2ID.Valid {
×
4876
                        policy2 = &sqlc.GraphChannelPolicy{
×
4877
                                ID:                      r.Policy2ID.Int64,
×
4878
                                Version:                 r.Policy2Version.Int16,
×
4879
                                ChannelID:               r.GraphChannel.ID,
×
4880
                                NodeID:                  r.Policy2NodeID.Int64,
×
4881
                                Timelock:                r.Policy2Timelock.Int32,
×
4882
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4883
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4884
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4885
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4886
                                LastUpdate:              r.Policy2LastUpdate,
×
4887
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4888
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4889
                                Disabled:                r.Policy2Disabled,
×
4890
                                MessageFlags:            r.Policy2MessageFlags,
×
4891
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4892
                                Signature:               r.Policy2Signature,
×
4893
                        }
×
4894
                }
×
4895

4896
                return policy1, policy2, nil
×
4897

4898
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4899
                if r.Policy1ID.Valid {
×
4900
                        policy1 = &sqlc.GraphChannelPolicy{
×
4901
                                ID:                      r.Policy1ID.Int64,
×
4902
                                Version:                 r.Policy1Version.Int16,
×
4903
                                ChannelID:               r.GraphChannel.ID,
×
4904
                                NodeID:                  r.Policy1NodeID.Int64,
×
4905
                                Timelock:                r.Policy1Timelock.Int32,
×
4906
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4907
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4908
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4909
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4910
                                LastUpdate:              r.Policy1LastUpdate,
×
4911
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4912
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4913
                                Disabled:                r.Policy1Disabled,
×
4914
                                MessageFlags:            r.Policy1MessageFlags,
×
4915
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4916
                                Signature:               r.Policy1Signature,
×
4917
                        }
×
4918
                }
×
4919
                if r.Policy2ID.Valid {
×
4920
                        policy2 = &sqlc.GraphChannelPolicy{
×
4921
                                ID:                      r.Policy2ID.Int64,
×
4922
                                Version:                 r.Policy2Version.Int16,
×
4923
                                ChannelID:               r.GraphChannel.ID,
×
4924
                                NodeID:                  r.Policy2NodeID.Int64,
×
4925
                                Timelock:                r.Policy2Timelock.Int32,
×
4926
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4927
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4928
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4929
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4930
                                LastUpdate:              r.Policy2LastUpdate,
×
4931
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4932
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4933
                                Disabled:                r.Policy2Disabled,
×
4934
                                MessageFlags:            r.Policy2MessageFlags,
×
4935
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4936
                                Signature:               r.Policy2Signature,
×
4937
                        }
×
4938
                }
×
4939

4940
                return policy1, policy2, nil
×
4941

4942
        case sqlc.GetChannelsByIDsRow:
×
4943
                if r.Policy1ID.Valid {
×
4944
                        policy1 = &sqlc.GraphChannelPolicy{
×
4945
                                ID:                      r.Policy1ID.Int64,
×
4946
                                Version:                 r.Policy1Version.Int16,
×
4947
                                ChannelID:               r.GraphChannel.ID,
×
4948
                                NodeID:                  r.Policy1NodeID.Int64,
×
4949
                                Timelock:                r.Policy1Timelock.Int32,
×
4950
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4951
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4952
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4953
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4954
                                LastUpdate:              r.Policy1LastUpdate,
×
4955
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4956
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4957
                                Disabled:                r.Policy1Disabled,
×
4958
                                MessageFlags:            r.Policy1MessageFlags,
×
4959
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4960
                                Signature:               r.Policy1Signature,
×
4961
                        }
×
4962
                }
×
4963
                if r.Policy2ID.Valid {
×
4964
                        policy2 = &sqlc.GraphChannelPolicy{
×
4965
                                ID:                      r.Policy2ID.Int64,
×
4966
                                Version:                 r.Policy2Version.Int16,
×
4967
                                ChannelID:               r.GraphChannel.ID,
×
4968
                                NodeID:                  r.Policy2NodeID.Int64,
×
4969
                                Timelock:                r.Policy2Timelock.Int32,
×
4970
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4971
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4972
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4973
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4974
                                LastUpdate:              r.Policy2LastUpdate,
×
4975
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4976
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4977
                                Disabled:                r.Policy2Disabled,
×
4978
                                MessageFlags:            r.Policy2MessageFlags,
×
4979
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4980
                                Signature:               r.Policy2Signature,
×
4981
                        }
×
4982
                }
×
4983

4984
                return policy1, policy2, nil
×
4985

4986
        default:
×
4987
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4988
                        "extractChannelPolicies: %T", r)
×
4989
        }
4990
}
4991

4992
// channelIDToBytes converts a channel ID (SCID) to a byte array
4993
// representation.
4994
func channelIDToBytes(channelID uint64) []byte {
×
4995
        var chanIDB [8]byte
×
4996
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4997

×
4998
        return chanIDB[:]
×
4999
}
×
5000

5001
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
5002
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
5003
        if len(addresses) == 0 {
×
5004
                return nil, nil
×
5005
        }
×
5006

5007
        result := make([]net.Addr, 0, len(addresses))
×
5008
        for _, addr := range addresses {
×
5009
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
5010
                if err != nil {
×
5011
                        return nil, fmt.Errorf("unable to parse address %s "+
×
5012
                                "of type %d: %w", addr.address, addr.addrType,
×
5013
                                err)
×
5014
                }
×
5015
                if netAddr != nil {
×
5016
                        result = append(result, netAddr)
×
5017
                }
×
5018
        }
5019

5020
        // If we have no valid addresses, return nil instead of empty slice.
5021
        if len(result) == 0 {
×
5022
                return nil, nil
×
5023
        }
×
5024

5025
        return result, nil
×
5026
}
5027

5028
// parseAddress parses the given address string based on the address type
5029
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
5030
// and opaque addresses.
5031
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
5032
        switch addrType {
×
5033
        case addressTypeIPv4:
×
5034
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
5035
                if err != nil {
×
5036
                        return nil, err
×
5037
                }
×
5038

5039
                tcp.IP = tcp.IP.To4()
×
5040

×
5041
                return tcp, nil
×
5042

5043
        case addressTypeIPv6:
×
5044
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
5045
                if err != nil {
×
5046
                        return nil, err
×
5047
                }
×
5048

5049
                return tcp, nil
×
5050

5051
        case addressTypeTorV3, addressTypeTorV2:
×
5052
                service, portStr, err := net.SplitHostPort(address)
×
5053
                if err != nil {
×
5054
                        return nil, fmt.Errorf("unable to split tor "+
×
5055
                                "address: %v", address)
×
5056
                }
×
5057

5058
                port, err := strconv.Atoi(portStr)
×
5059
                if err != nil {
×
5060
                        return nil, err
×
5061
                }
×
5062

5063
                return &tor.OnionAddr{
×
5064
                        OnionService: service,
×
5065
                        Port:         port,
×
5066
                }, nil
×
5067

5068
        case addressTypeDNS:
×
5069
                hostname, portStr, err := net.SplitHostPort(address)
×
5070
                if err != nil {
×
5071
                        return nil, fmt.Errorf("unable to split DNS "+
×
5072
                                "address: %v", address)
×
5073
                }
×
5074

5075
                port, err := strconv.Atoi(portStr)
×
5076
                if err != nil {
×
5077
                        return nil, err
×
5078
                }
×
5079

5080
                return &lnwire.DNSAddress{
×
5081
                        Hostname: hostname,
×
5082
                        Port:     uint16(port),
×
5083
                }, nil
×
5084

5085
        case addressTypeOpaque:
×
5086
                opaque, err := hex.DecodeString(address)
×
5087
                if err != nil {
×
5088
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5089
                                "address: %v", address)
×
5090
                }
×
5091

5092
                return &lnwire.OpaqueAddrs{
×
5093
                        Payload: opaque,
×
5094
                }, nil
×
5095

5096
        default:
×
5097
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5098
        }
5099
}
5100

5101
// batchNodeData holds all the related data for a batch of nodes.
5102
type batchNodeData struct {
5103
        // features is a map from a DB node ID to the feature bits for that
5104
        // node.
5105
        features map[int64][]int
5106

5107
        // addresses is a map from a DB node ID to the node's addresses.
5108
        addresses map[int64][]nodeAddress
5109

5110
        // extraFields is a map from a DB node ID to the extra signed fields
5111
        // for that node.
5112
        extraFields map[int64]map[uint64][]byte
5113
}
5114

5115
// nodeAddress holds the address type, position and address string for a
5116
// node. This is used to batch the fetching of node addresses.
5117
type nodeAddress struct {
5118
        addrType dbAddressType
5119
        position int32
5120
        address  string
5121
}
5122

5123
// batchLoadNodeData loads all related data for a batch of node IDs using the
5124
// provided SQLQueries interface. It returns a batchNodeData instance containing
5125
// the node features, addresses and extra signed fields.
5126
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
5127
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
5128

×
5129
        // Batch load the node features.
×
5130
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
5131
        if err != nil {
×
5132
                return nil, fmt.Errorf("unable to batch load node "+
×
5133
                        "features: %w", err)
×
5134
        }
×
5135

5136
        // Batch load the node addresses.
5137
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
5138
        if err != nil {
×
5139
                return nil, fmt.Errorf("unable to batch load node "+
×
5140
                        "addresses: %w", err)
×
5141
        }
×
5142

5143
        // Batch load the node extra signed fields.
5144
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
5145
        if err != nil {
×
5146
                return nil, fmt.Errorf("unable to batch load node extra "+
×
5147
                        "signed fields: %w", err)
×
5148
        }
×
5149

5150
        return &batchNodeData{
×
5151
                features:    features,
×
5152
                addresses:   addrs,
×
5153
                extraFields: extraTypes,
×
5154
        }, nil
×
5155
}
5156

5157
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5158
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5159
func batchLoadNodeFeaturesHelper(ctx context.Context,
5160
        cfg *sqldb.QueryConfig, db SQLQueries,
5161
        nodeIDs []int64) (map[int64][]int, error) {
×
5162

×
5163
        features := make(map[int64][]int)
×
5164

×
5165
        return features, sqldb.ExecuteBatchQuery(
×
5166
                ctx, cfg, nodeIDs,
×
5167
                func(id int64) int64 {
×
5168
                        return id
×
5169
                },
×
5170
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5171
                        error) {
×
5172

×
5173
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5174
                },
×
5175
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5176
                        features[feature.NodeID] = append(
×
5177
                                features[feature.NodeID],
×
5178
                                int(feature.FeatureBit),
×
5179
                        )
×
5180

×
5181
                        return nil
×
5182
                },
×
5183
        )
5184
}
5185

5186
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5187
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5188
// node ID to a slice of nodeAddress structs.
5189
func batchLoadNodeAddressesHelper(ctx context.Context,
5190
        cfg *sqldb.QueryConfig, db SQLQueries,
5191
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5192

×
5193
        addrs := make(map[int64][]nodeAddress)
×
5194

×
5195
        return addrs, sqldb.ExecuteBatchQuery(
×
5196
                ctx, cfg, nodeIDs,
×
5197
                func(id int64) int64 {
×
5198
                        return id
×
5199
                },
×
5200
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5201
                        error) {
×
5202

×
5203
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5204
                },
×
5205
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5206
                        addrs[addr.NodeID] = append(
×
5207
                                addrs[addr.NodeID], nodeAddress{
×
5208
                                        addrType: dbAddressType(addr.Type),
×
5209
                                        position: addr.Position,
×
5210
                                        address:  addr.Address,
×
5211
                                },
×
5212
                        )
×
5213

×
5214
                        return nil
×
5215
                },
×
5216
        )
5217
}
5218

5219
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5220
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5221
// query.
5222
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5223
        cfg *sqldb.QueryConfig, db SQLQueries,
5224
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5225

×
5226
        extraFields := make(map[int64]map[uint64][]byte)
×
5227

×
5228
        callback := func(ctx context.Context,
×
5229
                field sqlc.GraphNodeExtraType) error {
×
5230

×
5231
                if extraFields[field.NodeID] == nil {
×
5232
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5233
                }
×
5234
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5235

×
5236
                return nil
×
5237
        }
5238

5239
        return extraFields, sqldb.ExecuteBatchQuery(
×
5240
                ctx, cfg, nodeIDs,
×
5241
                func(id int64) int64 {
×
5242
                        return id
×
5243
                },
×
5244
                func(ctx context.Context, ids []int64) (
5245
                        []sqlc.GraphNodeExtraType, error) {
×
5246

×
5247
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5248
                },
×
5249
                callback,
5250
        )
5251
}
5252

5253
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5254
// from the provided sqlc.GraphChannelPolicy records and the
5255
// provided batchChannelData.
5256
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5257
        channelID uint64, node1, node2 route.Vertex,
5258
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5259
        *models.ChannelEdgePolicy, error) {
×
5260

×
5261
        pol1, err := buildChanPolicyWithBatchData(
×
5262
                dbPol1, channelID, node2, batchData,
×
5263
        )
×
5264
        if err != nil {
×
5265
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5266
        }
×
5267

5268
        pol2, err := buildChanPolicyWithBatchData(
×
5269
                dbPol2, channelID, node1, batchData,
×
5270
        )
×
5271
        if err != nil {
×
5272
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5273
        }
×
5274

5275
        return pol1, pol2, nil
×
5276
}
5277

5278
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5279
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5280
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5281
        channelID uint64, toNode route.Vertex,
5282
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
5283

×
5284
        if dbPol == nil {
×
5285
                return nil, nil
×
5286
        }
×
5287

5288
        var dbPol1Extras map[uint64][]byte
×
5289
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5290
                dbPol1Extras = extras
×
5291
        } else {
×
5292
                dbPol1Extras = make(map[uint64][]byte)
×
5293
        }
×
5294

5295
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5296
}
5297

5298
// batchChannelData holds all the related data for a batch of channels.
5299
type batchChannelData struct {
5300
        // chanFeatures is a map from DB channel ID to a slice of feature bits.
5301
        chanfeatures map[int64][]int
5302

5303
        // chanExtras is a map from DB channel ID to a map of TLV type to
5304
        // extra signed field bytes.
5305
        chanExtraTypes map[int64]map[uint64][]byte
5306

5307
        // policyExtras is a map from DB channel policy ID to a map of TLV type
5308
        // to extra signed field bytes.
5309
        policyExtras map[int64]map[uint64][]byte
5310
}
5311

5312
// batchLoadChannelData loads all related data for batches of channels and
5313
// policies.
5314
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5315
        db SQLQueries, channelIDs []int64,
5316
        policyIDs []int64) (*batchChannelData, error) {
×
5317

×
5318
        batchData := &batchChannelData{
×
5319
                chanfeatures:   make(map[int64][]int),
×
5320
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5321
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5322
        }
×
5323

×
5324
        // Batch load channel features and extras
×
5325
        var err error
×
5326
        if len(channelIDs) > 0 {
×
5327
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5328
                        ctx, cfg, db, channelIDs,
×
5329
                )
×
5330
                if err != nil {
×
5331
                        return nil, fmt.Errorf("unable to batch load "+
×
5332
                                "channel features: %w", err)
×
5333
                }
×
5334

5335
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5336
                        ctx, cfg, db, channelIDs,
×
5337
                )
×
5338
                if err != nil {
×
5339
                        return nil, fmt.Errorf("unable to batch load "+
×
5340
                                "channel extras: %w", err)
×
5341
                }
×
5342
        }
5343

5344
        if len(policyIDs) > 0 {
×
5345
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5346
                        ctx, cfg, db, policyIDs,
×
5347
                )
×
5348
                if err != nil {
×
5349
                        return nil, fmt.Errorf("unable to batch load "+
×
5350
                                "policy extras: %w", err)
×
5351
                }
×
5352
                batchData.policyExtras = policyExtras
×
5353
        }
5354

5355
        return batchData, nil
×
5356
}
5357

5358
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5359
// channel IDs using ExecuteBatchQuery wrapper around the
5360
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5361
// slice of feature bits.
5362
func batchLoadChannelFeaturesHelper(ctx context.Context,
5363
        cfg *sqldb.QueryConfig, db SQLQueries,
5364
        channelIDs []int64) (map[int64][]int, error) {
×
5365

×
5366
        features := make(map[int64][]int)
×
5367

×
5368
        return features, sqldb.ExecuteBatchQuery(
×
5369
                ctx, cfg, channelIDs,
×
5370
                func(id int64) int64 {
×
5371
                        return id
×
5372
                },
×
5373
                func(ctx context.Context,
5374
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5375

×
5376
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5377
                },
×
5378
                func(ctx context.Context,
5379
                        feature sqlc.GraphChannelFeature) error {
×
5380

×
5381
                        features[feature.ChannelID] = append(
×
5382
                                features[feature.ChannelID],
×
5383
                                int(feature.FeatureBit),
×
5384
                        )
×
5385

×
5386
                        return nil
×
5387
                },
×
5388
        )
5389
}
5390

5391
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5392
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5393
// query. It returns a map from DB channel ID to a map of TLV type to extra
5394
// signed field bytes.
5395
func batchLoadChannelExtrasHelper(ctx context.Context,
5396
        cfg *sqldb.QueryConfig, db SQLQueries,
5397
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5398

×
5399
        extras := make(map[int64]map[uint64][]byte)
×
5400

×
5401
        cb := func(ctx context.Context,
×
5402
                extra sqlc.GraphChannelExtraType) error {
×
5403

×
5404
                if extras[extra.ChannelID] == nil {
×
5405
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5406
                }
×
5407
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5408

×
5409
                return nil
×
5410
        }
5411

5412
        return extras, sqldb.ExecuteBatchQuery(
×
5413
                ctx, cfg, channelIDs,
×
5414
                func(id int64) int64 {
×
5415
                        return id
×
5416
                },
×
5417
                func(ctx context.Context,
5418
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5419

×
5420
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5421
                }, cb,
×
5422
        )
5423
}
5424

5425
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5426
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5427
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5428
// a map of TLV type to extra signed field bytes.
5429
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5430
        cfg *sqldb.QueryConfig, db SQLQueries,
5431
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5432

×
5433
        extras := make(map[int64]map[uint64][]byte)
×
5434

×
5435
        return extras, sqldb.ExecuteBatchQuery(
×
5436
                ctx, cfg, policyIDs,
×
5437
                func(id int64) int64 {
×
5438
                        return id
×
5439
                },
×
5440
                func(ctx context.Context, ids []int64) (
5441
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5442

×
5443
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5444
                },
×
5445
                func(ctx context.Context,
5446
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5447

×
5448
                        if extras[row.PolicyID] == nil {
×
5449
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5450
                        }
×
5451
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5452

×
5453
                        return nil
×
5454
                },
5455
        )
5456
}
5457

5458
// forEachNodePaginated executes a paginated query to process each node in the
5459
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5460
// and applies the provided processNode function to each node.
5461
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5462
        db SQLQueries, protocol lnwire.GossipVersion,
5463
        processNode func(context.Context, int64,
5464
                *models.Node) error) error {
×
5465

×
5466
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5467
                limit int32) ([]sqlc.GraphNode, error) {
×
5468

×
5469
                return db.ListNodesPaginated(
×
5470
                        ctx, sqlc.ListNodesPaginatedParams{
×
5471
                                Version: int16(protocol),
×
5472
                                ID:      lastID,
×
5473
                                Limit:   limit,
×
5474
                        },
×
5475
                )
×
5476
        }
×
5477

5478
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5479
                return node.ID
×
5480
        }
×
5481

5482
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5483
                return node.ID, nil
×
5484
        }
×
5485

5486
        batchQueryFunc := func(ctx context.Context,
×
5487
                nodeIDs []int64) (*batchNodeData, error) {
×
5488

×
5489
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5490
        }
×
5491

5492
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5493
                batchData *batchNodeData) error {
×
5494

×
5495
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5496
                if err != nil {
×
5497
                        return fmt.Errorf("unable to build "+
×
5498
                                "node(id=%d): %w", dbNode.ID, err)
×
5499
                }
×
5500

5501
                return processNode(ctx, dbNode.ID, node)
×
5502
        }
5503

5504
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5505
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5506
                collectFunc, batchQueryFunc, processItem,
×
5507
        )
×
5508
}
5509

5510
// forEachChannelWithPolicies executes a paginated query to process each channel
5511
// with policies in the graph.
5512
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5513
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5514
                *models.ChannelEdgePolicy,
5515
                *models.ChannelEdgePolicy) error) error {
×
5516

×
5517
        type channelBatchIDs struct {
×
5518
                channelID int64
×
5519
                policyIDs []int64
×
5520
        }
×
5521

×
5522
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5523
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5524
                error) {
×
5525

×
5526
                return db.ListChannelsWithPoliciesPaginated(
×
5527
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5528
                                Version: int16(lnwire.GossipVersion1),
×
5529
                                ID:      lastID,
×
5530
                                Limit:   limit,
×
5531
                        },
×
5532
                )
×
5533
        }
×
5534

5535
        extractPageCursor := func(
×
5536
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5537

×
5538
                return row.GraphChannel.ID
×
5539
        }
×
5540

5541
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5542
                channelBatchIDs, error) {
×
5543

×
5544
                ids := channelBatchIDs{
×
5545
                        channelID: row.GraphChannel.ID,
×
5546
                }
×
5547

×
5548
                // Extract policy IDs from the row.
×
5549
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5550
                if err != nil {
×
5551
                        return ids, err
×
5552
                }
×
5553

5554
                if dbPol1 != nil {
×
5555
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5556
                }
×
5557
                if dbPol2 != nil {
×
5558
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5559
                }
×
5560

5561
                return ids, nil
×
5562
        }
5563

5564
        batchDataFunc := func(ctx context.Context,
×
5565
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5566

×
5567
                // Separate channel IDs from policy IDs.
×
5568
                var (
×
5569
                        channelIDs = make([]int64, len(allIDs))
×
5570
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5571
                )
×
5572

×
5573
                for i, ids := range allIDs {
×
5574
                        channelIDs[i] = ids.channelID
×
5575
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5576
                }
×
5577

5578
                return batchLoadChannelData(
×
5579
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5580
                )
×
5581
        }
5582

5583
        processItem := func(ctx context.Context,
×
5584
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5585
                batchData *batchChannelData) error {
×
5586

×
5587
                node1, node2, err := buildNodeVertices(
×
5588
                        row.Node1Pubkey, row.Node2Pubkey,
×
5589
                )
×
5590
                if err != nil {
×
5591
                        return err
×
5592
                }
×
5593

5594
                edge, err := buildEdgeInfoWithBatchData(
×
5595
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5596
                        batchData,
×
5597
                )
×
5598
                if err != nil {
×
5599
                        return fmt.Errorf("unable to build channel info: %w",
×
5600
                                err)
×
5601
                }
×
5602

5603
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5604
                if err != nil {
×
5605
                        return err
×
5606
                }
×
5607

5608
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5609
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5610
                )
×
5611
                if err != nil {
×
5612
                        return err
×
5613
                }
×
5614

5615
                return processChannel(edge, p1, p2)
×
5616
        }
5617

5618
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5619
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5620
                collectFunc, batchDataFunc, processItem,
×
5621
        )
×
5622
}
5623

5624
// buildDirectedChannel builds a DirectedChannel instance from the provided
5625
// data.
5626
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5627
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5628
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5629
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5630

×
5631
        node1, node2, err := buildNodeVertices(
×
5632
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5633
        )
×
5634
        if err != nil {
×
5635
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5636
        }
×
5637

5638
        edge, err := buildEdgeInfoWithBatchData(
×
5639
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5640
        )
×
5641
        if err != nil {
×
5642
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5643
        }
×
5644

5645
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5646
        if err != nil {
×
5647
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5648
                        err)
×
5649
        }
×
5650

5651
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5652
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5653
                channelBatchData,
×
5654
        )
×
5655
        if err != nil {
×
5656
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5657
                        err)
×
5658
        }
×
5659

5660
        // Determine outgoing and incoming policy for this specific node.
5661
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5662
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5663
        outPolicy, inPolicy := p1, p2
×
5664
        if (p1 != nil && p1ToNode == nodeID) ||
×
5665
                (p2 != nil && p2ToNode != nodeID) {
×
5666

×
5667
                outPolicy, inPolicy = p2, p1
×
5668
        }
×
5669

5670
        // Build cached policy.
5671
        var cachedInPolicy *models.CachedEdgePolicy
×
5672
        if inPolicy != nil {
×
5673
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5674
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5675
                cachedInPolicy.ToNodeFeatures = features
×
5676
        }
×
5677

5678
        // Extract inbound fee.
5679
        var inboundFee lnwire.Fee
×
5680
        if outPolicy != nil {
×
5681
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5682
                        inboundFee = fee
×
5683
                })
×
5684
        }
5685

5686
        // Build directed channel.
5687
        directedChannel := &DirectedChannel{
×
5688
                ChannelID:    edge.ChannelID,
×
5689
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5690
                OtherNode:    edge.NodeKey2Bytes,
×
5691
                Capacity:     edge.Capacity,
×
5692
                OutPolicySet: outPolicy != nil,
×
5693
                InPolicy:     cachedInPolicy,
×
5694
                InboundFee:   inboundFee,
×
5695
        }
×
5696

×
5697
        if nodePub == edge.NodeKey2Bytes {
×
5698
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5699
        }
×
5700

5701
        return directedChannel, nil
×
5702
}
5703

5704
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5705
// provided rows. It uses batch loading for channels, policies, and nodes.
5706
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5707
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5708

×
5709
        var (
×
5710
                channelIDs = make([]int64, len(rows))
×
5711
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5712
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5713

×
5714
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5715
                nodeIDSet = make(map[int64]bool)
×
5716

×
5717
                // edges will hold the final channel edges built from the rows.
×
5718
                edges = make([]ChannelEdge, 0, len(rows))
×
5719
        )
×
5720

×
5721
        // Collect all IDs needed for batch loading.
×
5722
        for i, row := range rows {
×
5723
                channelIDs[i] = row.Channel().ID
×
5724

×
5725
                // Collect policy IDs
×
5726
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5727
                if err != nil {
×
5728
                        return nil, fmt.Errorf("unable to extract channel "+
×
5729
                                "policies: %w", err)
×
5730
                }
×
5731
                if dbPol1 != nil {
×
5732
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5733
                }
×
5734
                if dbPol2 != nil {
×
5735
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5736
                }
×
5737

5738
                var (
×
5739
                        node1ID = row.Node1().ID
×
5740
                        node2ID = row.Node2().ID
×
5741
                )
×
5742

×
5743
                // Collect unique node IDs.
×
5744
                if !nodeIDSet[node1ID] {
×
5745
                        nodeIDs = append(nodeIDs, node1ID)
×
5746
                        nodeIDSet[node1ID] = true
×
5747
                }
×
5748

5749
                if !nodeIDSet[node2ID] {
×
5750
                        nodeIDs = append(nodeIDs, node2ID)
×
5751
                        nodeIDSet[node2ID] = true
×
5752
                }
×
5753
        }
5754

5755
        // Batch the data for all the channels and policies.
5756
        channelBatchData, err := batchLoadChannelData(
×
5757
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5758
        )
×
5759
        if err != nil {
×
5760
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5761
                        "policy data: %w", err)
×
5762
        }
×
5763

5764
        // Batch the data for all the nodes.
5765
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5766
        if err != nil {
×
5767
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5768
                        err)
×
5769
        }
×
5770

5771
        // Build all channel edges using batch data.
5772
        for _, row := range rows {
×
5773
                // Build nodes using batch data.
×
5774
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5775
                if err != nil {
×
5776
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5777
                }
×
5778

5779
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5780
                if err != nil {
×
5781
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5782
                }
×
5783

5784
                // Build channel info using batch data.
5785
                channel, err := buildEdgeInfoWithBatchData(
×
5786
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5787
                        node2.PubKeyBytes, channelBatchData,
×
5788
                )
×
5789
                if err != nil {
×
5790
                        return nil, fmt.Errorf("unable to build channel "+
×
5791
                                "info: %w", err)
×
5792
                }
×
5793

5794
                // Extract and build policies using batch data.
5795
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5796
                if err != nil {
×
5797
                        return nil, fmt.Errorf("unable to extract channel "+
×
5798
                                "policies: %w", err)
×
5799
                }
×
5800

5801
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5802
                        dbPol1, dbPol2, channel.ChannelID,
×
5803
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5804
                )
×
5805
                if err != nil {
×
5806
                        return nil, fmt.Errorf("unable to build channel "+
×
5807
                                "policies: %w", err)
×
5808
                }
×
5809

5810
                edges = append(edges, ChannelEdge{
×
5811
                        Info:    channel,
×
5812
                        Policy1: p1,
×
5813
                        Policy2: p2,
×
5814
                        Node1:   node1,
×
5815
                        Node2:   node2,
×
5816
                })
×
5817
        }
5818

5819
        return edges, nil
×
5820
}
5821

5822
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5823
// instances from the provided rows using batch loading for channel data.
5824
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5825
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5826
        []*models.ChannelEdgeInfo, []int64, error) {
×
5827

×
5828
        if len(rows) == 0 {
×
5829
                return nil, nil, nil
×
5830
        }
×
5831

5832
        // Collect all the channel IDs needed for batch loading.
5833
        channelIDs := make([]int64, len(rows))
×
5834
        for i, row := range rows {
×
5835
                channelIDs[i] = row.Channel().ID
×
5836
        }
×
5837

5838
        // Batch load the channel data.
5839
        channelBatchData, err := batchLoadChannelData(
×
5840
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5841
        )
×
5842
        if err != nil {
×
5843
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5844
                        "data: %w", err)
×
5845
        }
×
5846

5847
        // Build all channel edges using batch data.
5848
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5849
        for _, row := range rows {
×
5850
                node1, node2, err := buildNodeVertices(
×
5851
                        row.Node1Pub(), row.Node2Pub(),
×
5852
                )
×
5853
                if err != nil {
×
5854
                        return nil, nil, err
×
5855
                }
×
5856

5857
                // Build channel info using batch data
5858
                info, err := buildEdgeInfoWithBatchData(
×
5859
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5860
                        channelBatchData,
×
5861
                )
×
5862
                if err != nil {
×
5863
                        return nil, nil, err
×
5864
                }
×
5865

5866
                edges = append(edges, info)
×
5867
        }
5868

5869
        return edges, channelIDs, nil
×
5870
}
5871

5872
// handleZombieMarking is a helper function that handles the logic of
5873
// marking a channel as a zombie in the database. It takes into account whether
5874
// we are in strict zombie pruning mode, and adjusts the node public keys
5875
// accordingly based on the last update timestamps of the channel policies.
5876
func handleZombieMarking(ctx context.Context, db SQLQueries,
5877
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5878
        strictZombiePruning bool, scid uint64) error {
×
5879

×
5880
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5881

×
5882
        if strictZombiePruning {
×
5883
                var e1UpdateTime, e2UpdateTime *time.Time
×
5884
                if row.Policy1LastUpdate.Valid {
×
5885
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5886
                        e1UpdateTime = &e1Time
×
5887
                }
×
5888
                if row.Policy2LastUpdate.Valid {
×
5889
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5890
                        e2UpdateTime = &e2Time
×
5891
                }
×
5892

5893
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5894
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5895
                        e2UpdateTime,
×
5896
                )
×
5897
        }
5898

5899
        return db.UpsertZombieChannel(
×
5900
                ctx, sqlc.UpsertZombieChannelParams{
×
5901
                        Version:  int16(lnwire.GossipVersion1),
×
5902
                        Scid:     channelIDToBytes(scid),
×
5903
                        NodeKey1: nodeKey1[:],
×
5904
                        NodeKey2: nodeKey2[:],
×
5905
                },
×
5906
        )
×
5907
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc