• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 19512294102

19 Nov 2025 06:31PM UTC coverage: 65.184% (-0.06%) from 65.243%
19512294102

Pull #10339

github

web-flow
Merge 5aad165f2 into 194a9f759
Pull Request #10339: [g175:2] graph/db: v2 columns and v2 node CRUD

205 of 389 new or added lines in 11 files covered. (52.7%)

119 existing lines in 28 files now uncovered.

137707 of 211260 relevant lines covered (65.18%)

20761.01 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
37
// execute queries against the SQL graph tables.
38
//
39
//nolint:ll,interfacebloat
40
type SQLQueries interface {
41
        /*
42
                Node queries.
43
        */
44
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
45
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
46
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
47
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
48
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
49
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
50
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
51
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
52
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
53
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
54
        DeleteNode(ctx context.Context, id int64) error
55

56
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
57
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
58
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
59
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
60

61
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
62
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
63
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
64
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
65

66
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
67
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
68
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
69
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
70
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
71

72
        /*
73
                Source node queries.
74
        */
75
        AddSourceNode(ctx context.Context, nodeID int64) error
76
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
77

78
        /*
79
                Channel queries.
80
        */
81
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
82
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
83
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
84
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
85
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
86
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
87
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
88
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
89
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
90
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
91
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
92
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
93
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
94
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
95
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
96
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
97
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
98
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
99
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
100
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
101
        DeleteChannels(ctx context.Context, ids []int64) error
102

103
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
104
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
105
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
106
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
107

108
        /*
109
                Channel Policy table queries.
110
        */
111
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
112
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
113
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
114

115
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
116
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
117
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
118

119
        /*
120
                Zombie index queries.
121
        */
122
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
123
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
124
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
125
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
126
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
127
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
128

129
        /*
130
                Prune log table queries.
131
        */
132
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
133
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
134
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
135
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
136
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
137

138
        /*
139
                Closed SCID table queries.
140
        */
141
        InsertClosedChannel(ctx context.Context, scid []byte) error
142
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
143
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
144

145
        /*
146
                Migration specific queries.
147

148
                NOTE: these should not be used in code other than migrations.
149
                Once sqldbv2 is in place, these can be removed from this struct
150
                as then migrations will have their own dedicated queries
151
                structs.
152
        */
153
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
154
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
155
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
156
}
157

158
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
159
// database operations.
160
type BatchedSQLQueries interface {
161
        SQLQueries
162
        sqldb.BatchedTx[SQLQueries]
163
}
164

165
// SQLStore is an implementation of the Store interface that uses a SQL
166
// database as the backend.
167
type SQLStore struct {
168
        cfg *SQLStoreConfig
169
        db  BatchedSQLQueries
170

171
        // cacheMu guards all caches (rejectCache and chanCache). If
172
        // this mutex will be acquired at the same time as the DB mutex then
173
        // the cacheMu MUST be acquired first to prevent deadlock.
174
        cacheMu     sync.RWMutex
175
        rejectCache *rejectCache
176
        chanCache   *channelCache
177

178
        chanScheduler batch.Scheduler[SQLQueries]
179
        nodeScheduler batch.Scheduler[SQLQueries]
180

181
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
182
        srcNodeMu sync.Mutex
183
}
184

185
// A compile-time assertion to ensure that SQLStore implements the Store
186
// interface.
187
var _ Store = (*SQLStore)(nil)
188

189
// SQLStoreConfig holds the configuration for the SQLStore.
190
type SQLStoreConfig struct {
191
        // ChainHash is the genesis hash for the chain that all the gossip
192
        // messages in this store are aimed at.
193
        ChainHash chainhash.Hash
194

195
        // QueryConfig holds configuration values for SQL queries.
196
        QueryCfg *sqldb.QueryConfig
197
}
198

199
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
200
// storage backend.
201
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
202
        options ...StoreOptionModifier) (*SQLStore, error) {
×
203

×
204
        opts := DefaultOptions()
×
205
        for _, o := range options {
×
206
                o(opts)
×
207
        }
×
208

209
        if opts.NoMigration {
×
210
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
211
                        "supported for SQL stores")
×
212
        }
×
213

214
        s := &SQLStore{
×
215
                cfg:         cfg,
×
216
                db:          db,
×
217
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
218
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
219
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
220
        }
×
221

×
222
        s.chanScheduler = batch.NewTimeScheduler(
×
223
                db, &s.cacheMu, opts.BatchCommitInterval,
×
224
        )
×
225
        s.nodeScheduler = batch.NewTimeScheduler(
×
226
                db, nil, opts.BatchCommitInterval,
×
227
        )
×
228

×
229
        return s, nil
×
230
}
231

232
// AddNode adds a vertex/node to the graph database. If the node is not
233
// in the database from before, this will add a new, unconnected one to the
234
// graph. If it is present from before, this will update that node's
235
// information.
236
//
237
// NOTE: part of the Store interface.
238
func (s *SQLStore) AddNode(ctx context.Context,
239
        node *models.Node, opts ...batch.SchedulerOption) error {
×
240

×
241
        r := &batch.Request[SQLQueries]{
×
242
                Opts: batch.NewSchedulerOptions(opts...),
×
243
                Do: func(queries SQLQueries) error {
×
244
                        _, err := upsertNode(ctx, queries, node)
×
245

×
246
                        // It is possible that two of the same node
×
247
                        // announcements are both being processed in the same
×
248
                        // batch. This may case the UpsertNode conflict to
×
249
                        // be hit since we require at the db layer that the
×
250
                        // new last_update is greater than the existing
×
251
                        // last_update. We need to gracefully handle this here.
×
252
                        if errors.Is(err, sql.ErrNoRows) {
×
253
                                return nil
×
254
                        }
×
255

256
                        return err
×
257
                },
258
        }
259

260
        return s.nodeScheduler.Execute(ctx, r)
×
261
}
262

263
// FetchNode attempts to look up a target node by its identity public
264
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
265
// returned.
266
//
267
// NOTE: part of the Store interface.
268
func (s *SQLStore) FetchNode(ctx context.Context, v lnwire.GossipVersion,
269
        pubKey route.Vertex) (*models.Node, error) {
×
270

×
271
        var node *models.Node
×
272
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
273
                var err error
×
NEW
274
                _, node, err = getNodeByPubKey(
×
NEW
275
                        ctx, s.cfg.QueryCfg, db, v, pubKey,
×
NEW
276
                )
×
277

×
278
                return err
×
279
        }, sqldb.NoOpReset)
×
280
        if err != nil {
×
281
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
282
        }
×
283

284
        return node, nil
×
285
}
286

287
// HasV1Node determines if the graph has a vertex identified by the
288
// target node identity public key. If the node exists in the database, a
289
// timestamp of when the data for the node was lasted updated is returned along
290
// with a true boolean. Otherwise, an empty time.Time is returned with a false
291
// boolean.
292
//
293
// NOTE: part of the Store interface.
294
func (s *SQLStore) HasV1Node(ctx context.Context,
295
        pubKey [33]byte) (time.Time, bool, error) {
×
296

×
297
        var (
×
298
                exists     bool
×
299
                lastUpdate time.Time
×
300
        )
×
301
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
302
                dbNode, err := db.GetNodeByPubKey(
×
303
                        ctx, sqlc.GetNodeByPubKeyParams{
×
304
                                Version: int16(lnwire.GossipVersion1),
×
305
                                PubKey:  pubKey[:],
×
306
                        },
×
307
                )
×
308
                if errors.Is(err, sql.ErrNoRows) {
×
309
                        return nil
×
310
                } else if err != nil {
×
311
                        return fmt.Errorf("unable to fetch node: %w", err)
×
312
                }
×
313

314
                exists = true
×
315

×
316
                if dbNode.LastUpdate.Valid {
×
317
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
318
                }
×
319

320
                return nil
×
321
        }, sqldb.NoOpReset)
322
        if err != nil {
×
323
                return time.Time{}, false,
×
324
                        fmt.Errorf("unable to fetch node: %w", err)
×
325
        }
×
326

327
        return lastUpdate, exists, nil
×
328
}
329

330
// HasNode determines if the graph has a vertex identified by the
331
// target node identity public key.
332
//
333
// NOTE: part of the Store interface.
334
func (s *SQLStore) HasNode(ctx context.Context, v lnwire.GossipVersion,
NEW
335
        pubKey [33]byte) (bool, error) {
×
NEW
336

×
NEW
337
        var exists bool
×
NEW
338
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
339
                _, err := db.GetNodeByPubKey(
×
NEW
340
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
341
                                Version: int16(v),
×
NEW
342
                                PubKey:  pubKey[:],
×
NEW
343
                        },
×
NEW
344
                )
×
NEW
345
                if errors.Is(err, sql.ErrNoRows) {
×
NEW
346
                        return nil
×
NEW
347
                } else if err != nil {
×
NEW
348
                        return fmt.Errorf("unable to fetch node: %w", err)
×
NEW
349
                }
×
350

NEW
351
                exists = true
×
NEW
352

×
NEW
353
                return nil
×
354
        }, sqldb.NoOpReset)
NEW
355
        if err != nil {
×
NEW
356
                return false, fmt.Errorf("unable to fetch node: %w", err)
×
NEW
357
        }
×
358

NEW
359
        return exists, nil
×
360
}
361

362
// AddrsForNode returns all known addresses for the target node public key
363
// that the graph DB is aware of. The returned boolean indicates if the
364
// given node is unknown to the graph DB or not.
365
//
366
// NOTE: part of the Store interface.
367
func (s *SQLStore) AddrsForNode(ctx context.Context, v lnwire.GossipVersion,
368
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
369

×
370
        var (
×
371
                addresses []net.Addr
×
372
                known     bool
×
373
        )
×
374
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
375
                // First, check if the node exists and get its DB ID if it
×
376
                // does.
×
377
                dbID, err := db.GetNodeIDByPubKey(
×
378
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
NEW
379
                                Version: int16(v),
×
380
                                PubKey:  nodePub.SerializeCompressed(),
×
381
                        },
×
382
                )
×
383
                if errors.Is(err, sql.ErrNoRows) {
×
384
                        return nil
×
385
                }
×
386

387
                known = true
×
388

×
389
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
390
                if err != nil {
×
391
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
392
                                err)
×
393
                }
×
394

395
                return nil
×
396
        }, sqldb.NoOpReset)
397
        if err != nil {
×
398
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
399
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
400
        }
×
401

402
        return known, addresses, nil
×
403
}
404

405
// DeleteNode starts a new database transaction to remove a vertex/node
406
// from the database according to the node's public key.
407
//
408
// NOTE: part of the Store interface.
409
func (s *SQLStore) DeleteNode(ctx context.Context, v lnwire.GossipVersion,
410
        pubKey route.Vertex) error {
×
411

×
412
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
413
                res, err := db.DeleteNodeByPubKey(
×
414
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
NEW
415
                                Version: int16(v),
×
416
                                PubKey:  pubKey[:],
×
417
                        },
×
418
                )
×
419
                if err != nil {
×
420
                        return err
×
421
                }
×
422

423
                rows, err := res.RowsAffected()
×
424
                if err != nil {
×
425
                        return err
×
426
                }
×
427

428
                if rows == 0 {
×
429
                        return ErrGraphNodeNotFound
×
430
                } else if rows > 1 {
×
431
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
432
                }
×
433

434
                return err
×
435
        }, sqldb.NoOpReset)
436
        if err != nil {
×
437
                return fmt.Errorf("unable to delete node: %w", err)
×
438
        }
×
439

440
        return nil
×
441
}
442

443
// FetchNodeFeatures returns the features of the given node. If no features are
444
// known for the node, an empty feature vector is returned.
445
//
446
// NOTE: this is part of the graphdb.NodeTraverser interface.
447
func (s *SQLStore) FetchNodeFeatures(v lnwire.GossipVersion,
NEW
448
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
449

×
450
        ctx := context.TODO()
×
451

×
NEW
452
        return fetchNodeFeatures(ctx, s.db, v, nodePub)
×
453
}
×
454

455
// DisabledChannelIDs returns the channel ids of disabled channels.
456
// A channel is disabled when two of the associated ChanelEdgePolicies
457
// have their disabled bit on.
458
//
459
// NOTE: part of the Store interface.
460
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
461
        var (
×
462
                ctx     = context.TODO()
×
463
                chanIDs []uint64
×
464
        )
×
465
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
466
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
467
                if err != nil {
×
468
                        return fmt.Errorf("unable to fetch disabled "+
×
469
                                "channels: %w", err)
×
470
                }
×
471

472
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
473

×
474
                return nil
×
475
        }, sqldb.NoOpReset)
476
        if err != nil {
×
477
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
478
                        err)
×
479
        }
×
480

481
        return chanIDs, nil
×
482
}
483

484
// LookupAlias attempts to return the alias as advertised by the target node.
485
//
486
// NOTE: part of the Store interface.
487
func (s *SQLStore) LookupAlias(ctx context.Context, v lnwire.GossipVersion,
488
        pub *btcec.PublicKey) (string, error) {
×
489

×
490
        var alias string
×
491
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
492
                dbNode, err := db.GetNodeByPubKey(
×
493
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
494
                                Version: int16(v),
×
495
                                PubKey:  pub.SerializeCompressed(),
×
496
                        },
×
497
                )
×
498
                if errors.Is(err, sql.ErrNoRows) {
×
499
                        return ErrNodeAliasNotFound
×
500
                } else if err != nil {
×
501
                        return fmt.Errorf("unable to fetch node: %w", err)
×
502
                }
×
503

504
                if !dbNode.Alias.Valid {
×
505
                        return ErrNodeAliasNotFound
×
506
                }
×
507

508
                alias = dbNode.Alias.String
×
509

×
510
                return nil
×
511
        }, sqldb.NoOpReset)
512
        if err != nil {
×
513
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
514
        }
×
515

516
        return alias, nil
×
517
}
518

519
// SourceNode returns the source node of the graph. The source node is treated
520
// as the center node within a star-graph. This method may be used to kick off
521
// a path finding algorithm in order to explore the reachability of another
522
// node based off the source node.
523
//
524
// NOTE: part of the Store interface.
525
func (s *SQLStore) SourceNode(ctx context.Context,
NEW
526
        v lnwire.GossipVersion) (*models.Node, error) {
×
527

×
528
        var node *models.Node
×
529
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
530
                _, nodePub, err := s.getSourceNode(ctx, db, v)
×
531
                if err != nil {
×
532
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
533
                                err)
×
534
                }
×
535

NEW
536
                _, node, err = getNodeByPubKey(
×
NEW
537
                        ctx, s.cfg.QueryCfg, db, v, nodePub,
×
NEW
538
                )
×
539

×
540
                return err
×
541
        }, sqldb.NoOpReset)
542
        if err != nil {
×
543
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
544
        }
×
545

546
        return node, nil
×
547
}
548

549
// SetSourceNode sets the source node within the graph database. The source
550
// node is to be used as the center of a star-graph within path finding
551
// algorithms.
552
//
553
// NOTE: part of the Store interface.
554
func (s *SQLStore) SetSourceNode(ctx context.Context,
555
        node *models.Node) error {
×
556

×
557
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
558
                id, err := upsertNode(ctx, db, node)
×
559
                if err != nil {
×
560
                        return fmt.Errorf("unable to upsert source node: %w",
×
561
                                err)
×
562
                }
×
563

564
                // Make sure that if a source node for this version is already
565
                // set, then the ID is the same as the one we are about to set.
566
                dbSourceNodeID, _, err := s.getSourceNode(
×
567
                        ctx, db, lnwire.GossipVersion1,
×
568
                )
×
569
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
570
                        return fmt.Errorf("unable to fetch source node: %w",
×
571
                                err)
×
572
                } else if err == nil {
×
573
                        if dbSourceNodeID != id {
×
574
                                return fmt.Errorf("v1 source node already "+
×
575
                                        "set to a different node: %d vs %d",
×
576
                                        dbSourceNodeID, id)
×
577
                        }
×
578

579
                        return nil
×
580
                }
581

582
                return db.AddSourceNode(ctx, id)
×
583
        }, sqldb.NoOpReset)
584
}
585

586
// NodeUpdatesInHorizon returns all the known lightning node which have an
587
// update timestamp within the passed range. This method can be used by two
588
// nodes to quickly determine if they have the same set of up to date node
589
// announcements.
590
//
591
// NOTE: This is part of the Store interface.
592
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
593
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
594

×
595
        cfg := defaultIteratorConfig()
×
596
        for _, opt := range opts {
×
597
                opt(cfg)
×
598
        }
×
599

600
        return func(yield func(*models.Node, error) bool) {
×
601
                var (
×
602
                        ctx            = context.TODO()
×
603
                        lastUpdateTime sql.NullInt64
×
604
                        lastPubKey     = make([]byte, 33)
×
605
                        hasMore        = true
×
606
                )
×
607

×
608
                // Each iteration, we'll read a batch amount of nodes, yield
×
609
                // them, then decide is we have more or not.
×
610
                for hasMore {
×
611
                        var batch []*models.Node
×
612

×
613
                        //nolint:ll
×
614
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
615
                                //nolint:ll
×
616
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
617
                                        StartTime: sqldb.SQLInt64(
×
618
                                                startTime.Unix(),
×
619
                                        ),
×
620
                                        EndTime: sqldb.SQLInt64(
×
621
                                                endTime.Unix(),
×
622
                                        ),
×
623
                                        LastUpdate: lastUpdateTime,
×
624
                                        LastPubKey: lastPubKey,
×
625
                                        OnlyPublic: sql.NullBool{
×
626
                                                Bool:  cfg.iterPublicNodes,
×
627
                                                Valid: true,
×
628
                                        },
×
629
                                        MaxResults: sqldb.SQLInt32(
×
630
                                                cfg.nodeUpdateIterBatchSize,
×
631
                                        ),
×
632
                                }
×
633
                                rows, err := db.GetNodesByLastUpdateRange(
×
634
                                        ctx, params,
×
635
                                )
×
636
                                if err != nil {
×
637
                                        return err
×
638
                                }
×
639

640
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
641

×
642
                                err = forEachNodeInBatch(
×
643
                                        ctx, s.cfg.QueryCfg, db, rows,
×
644
                                        func(_ int64, node *models.Node) error {
×
645
                                                batch = append(batch, node)
×
646

×
647
                                                // Update pagination cursors
×
648
                                                // based on the last processed
×
649
                                                // node.
×
650
                                                lastUpdateTime = sql.NullInt64{
×
651
                                                        Int64: node.LastUpdate.
×
652
                                                                Unix(),
×
653
                                                        Valid: true,
×
654
                                                }
×
655
                                                lastPubKey = node.PubKeyBytes[:]
×
656

×
657
                                                return nil
×
658
                                        },
×
659
                                )
660
                                if err != nil {
×
661
                                        return fmt.Errorf("unable to build "+
×
662
                                                "nodes: %w", err)
×
663
                                }
×
664

665
                                return nil
×
666
                        }, func() {
×
667
                                batch = []*models.Node{}
×
668
                        })
×
669

670
                        if err != nil {
×
671
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
672
                                        "error: %v", err)
×
673

×
674
                                yield(&models.Node{}, err)
×
675

×
676
                                return
×
677
                        }
×
678

679
                        for _, node := range batch {
×
680
                                if !yield(node, nil) {
×
681
                                        return
×
682
                                }
×
683
                        }
684

685
                        // If the batch didn't yield anything, then we're done.
686
                        if len(batch) == 0 {
×
687
                                break
×
688
                        }
689
                }
690
        }
691
}
692

693
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
694
// undirected edge from the two target nodes are created. The information stored
695
// denotes the static attributes of the channel, such as the channelID, the keys
696
// involved in creation of the channel, and the set of features that the channel
697
// supports. The chanPoint and chanID are used to uniquely identify the edge
698
// globally within the database.
699
//
700
// NOTE: part of the Store interface.
701
func (s *SQLStore) AddChannelEdge(ctx context.Context,
702
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
703

×
704
        var alreadyExists bool
×
705
        r := &batch.Request[SQLQueries]{
×
706
                Opts: batch.NewSchedulerOptions(opts...),
×
707
                Reset: func() {
×
708
                        alreadyExists = false
×
709
                },
×
710
                Do: func(tx SQLQueries) error {
×
711
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
712

×
713
                        // Make sure that the channel doesn't already exist. We
×
714
                        // do this explicitly instead of relying on catching a
×
715
                        // unique constraint error because relying on SQL to
×
716
                        // throw that error would abort the entire batch of
×
717
                        // transactions.
×
718
                        _, err := tx.GetChannelBySCID(
×
719
                                ctx, sqlc.GetChannelBySCIDParams{
×
720
                                        Scid:    chanIDB,
×
721
                                        Version: int16(lnwire.GossipVersion1),
×
722
                                },
×
723
                        )
×
724
                        if err == nil {
×
725
                                alreadyExists = true
×
726
                                return nil
×
727
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
728
                                return fmt.Errorf("unable to fetch channel: %w",
×
729
                                        err)
×
730
                        }
×
731

732
                        return insertChannel(ctx, tx, edge)
×
733
                },
734
                OnCommit: func(err error) error {
×
735
                        switch {
×
736
                        case err != nil:
×
737
                                return err
×
738
                        case alreadyExists:
×
739
                                return ErrEdgeAlreadyExist
×
740
                        default:
×
741
                                s.rejectCache.remove(edge.ChannelID)
×
742
                                s.chanCache.remove(edge.ChannelID)
×
743
                                return nil
×
744
                        }
745
                },
746
        }
747

748
        return s.chanScheduler.Execute(ctx, r)
×
749
}
750

751
// HighestChanID returns the "highest" known channel ID in the channel graph.
752
// This represents the "newest" channel from the PoV of the chain. This method
753
// can be used by peers to quickly determine if their graphs are in sync.
754
//
755
// NOTE: This is part of the Store interface.
756
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
757
        var highestChanID uint64
×
758
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
759
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
760
                if errors.Is(err, sql.ErrNoRows) {
×
761
                        return nil
×
762
                } else if err != nil {
×
763
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
764
                                err)
×
765
                }
×
766

767
                highestChanID = byteOrder.Uint64(chanID)
×
768

×
769
                return nil
×
770
        }, sqldb.NoOpReset)
771
        if err != nil {
×
772
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
773
        }
×
774

775
        return highestChanID, nil
×
776
}
777

778
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
779
// within the database for the referenced channel. The `flags` attribute within
780
// the ChannelEdgePolicy determines which of the directed edges are being
781
// updated. If the flag is 1, then the first node's information is being
782
// updated, otherwise it's the second node's information. The node ordering is
783
// determined by the lexicographical ordering of the identity public keys of the
784
// nodes on either side of the channel.
785
//
786
// NOTE: part of the Store interface.
787
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
788
        edge *models.ChannelEdgePolicy,
789
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
790

×
791
        var (
×
792
                isUpdate1    bool
×
793
                edgeNotFound bool
×
794
                from, to     route.Vertex
×
795
        )
×
796

×
797
        r := &batch.Request[SQLQueries]{
×
798
                Opts: batch.NewSchedulerOptions(opts...),
×
799
                Reset: func() {
×
800
                        isUpdate1 = false
×
801
                        edgeNotFound = false
×
802
                },
×
803
                Do: func(tx SQLQueries) error {
×
804
                        var err error
×
805
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
806
                                ctx, tx, edge,
×
807
                        )
×
808
                        // It is possible that two of the same policy
×
809
                        // announcements are both being processed in the same
×
810
                        // batch. This may case the UpsertEdgePolicy conflict to
×
811
                        // be hit since we require at the db layer that the
×
812
                        // new last_update is greater than the existing
×
813
                        // last_update. We need to gracefully handle this here.
×
814
                        if errors.Is(err, sql.ErrNoRows) {
×
815
                                return nil
×
816
                        } else if err != nil {
×
817
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
818
                        }
×
819

820
                        // Silence ErrEdgeNotFound so that the batch can
821
                        // succeed, but propagate the error via local state.
822
                        if errors.Is(err, ErrEdgeNotFound) {
×
823
                                edgeNotFound = true
×
824
                                return nil
×
825
                        }
×
826

827
                        return err
×
828
                },
829
                OnCommit: func(err error) error {
×
830
                        switch {
×
831
                        case err != nil:
×
832
                                return err
×
833
                        case edgeNotFound:
×
834
                                return ErrEdgeNotFound
×
835
                        default:
×
836
                                s.updateEdgeCache(edge, isUpdate1)
×
837
                                return nil
×
838
                        }
839
                },
840
        }
841

842
        err := s.chanScheduler.Execute(ctx, r)
×
843

×
844
        return from, to, err
×
845
}
846

847
// updateEdgeCache updates our reject and channel caches with the new
848
// edge policy information.
849
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
850
        isUpdate1 bool) {
×
851

×
852
        // If an entry for this channel is found in reject cache, we'll modify
×
853
        // the entry with the updated timestamp for the direction that was just
×
854
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
855
        // during the next query for this edge.
×
856
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
857
                if isUpdate1 {
×
858
                        entry.upd1Time = e.LastUpdate.Unix()
×
859
                } else {
×
860
                        entry.upd2Time = e.LastUpdate.Unix()
×
861
                }
×
862
                s.rejectCache.insert(e.ChannelID, entry)
×
863
        }
864

865
        // If an entry for this channel is found in channel cache, we'll modify
866
        // the entry with the updated policy for the direction that was just
867
        // written. If the edge doesn't exist, we'll defer loading the info and
868
        // policies and lazily read from disk during the next query.
869
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
870
                if isUpdate1 {
×
871
                        channel.Policy1 = e
×
872
                } else {
×
873
                        channel.Policy2 = e
×
874
                }
×
875
                s.chanCache.insert(e.ChannelID, channel)
×
876
        }
877
}
878

879
// ForEachSourceNodeChannel iterates through all channels of the source node,
880
// executing the passed callback on each. The call-back is provided with the
881
// channel's outpoint, whether we have a policy for the channel and the channel
882
// peer's node information.
883
//
884
// NOTE: part of the Store interface.
885
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
886
        cb func(chanPoint wire.OutPoint, havePolicy bool,
887
                otherNode *models.Node) error, reset func()) error {
×
888

×
889
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
890
                nodeID, nodePub, err := s.getSourceNode(
×
891
                        ctx, db, lnwire.GossipVersion1,
×
892
                )
×
893
                if err != nil {
×
894
                        return fmt.Errorf("unable to fetch source node: %w",
×
895
                                err)
×
896
                }
×
897

898
                return forEachNodeChannel(
×
899
                        ctx, db, s.cfg, nodeID,
×
900
                        func(info *models.ChannelEdgeInfo,
×
901
                                outPolicy *models.ChannelEdgePolicy,
×
902
                                _ *models.ChannelEdgePolicy) error {
×
903

×
904
                                // Fetch the other node.
×
905
                                var (
×
906
                                        otherNodePub [33]byte
×
907
                                        node1        = info.NodeKey1Bytes
×
908
                                        node2        = info.NodeKey2Bytes
×
909
                                )
×
910
                                switch {
×
911
                                case bytes.Equal(node1[:], nodePub[:]):
×
912
                                        otherNodePub = node2
×
913
                                case bytes.Equal(node2[:], nodePub[:]):
×
914
                                        otherNodePub = node1
×
915
                                default:
×
916
                                        return fmt.Errorf("node not " +
×
917
                                                "participating in this channel")
×
918
                                }
919

920
                                _, otherNode, err := getNodeByPubKey(
×
NEW
921
                                        ctx, s.cfg.QueryCfg, db,
×
NEW
922
                                        lnwire.GossipVersion1, otherNodePub,
×
923
                                )
×
924
                                if err != nil {
×
925
                                        return fmt.Errorf("unable to fetch "+
×
926
                                                "other node(%x): %w",
×
927
                                                otherNodePub, err)
×
928
                                }
×
929

930
                                return cb(
×
931
                                        info.ChannelPoint, outPolicy != nil,
×
932
                                        otherNode,
×
933
                                )
×
934
                        },
935
                )
936
        }, reset)
937
}
938

939
// ForEachNode iterates through all the stored vertices/nodes in the graph,
940
// executing the passed callback with each node encountered. If the callback
941
// returns an error, then the transaction is aborted and the iteration stops
942
// early.
943
//
944
// NOTE: part of the Store interface.
945
func (s *SQLStore) ForEachNode(ctx context.Context,
946
        cb func(node *models.Node) error, reset func()) error {
×
947

×
948
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
949
                return forEachNodePaginated(
×
950
                        ctx, s.cfg.QueryCfg, db,
×
951
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
952
                                node *models.Node) error {
×
953

×
954
                                return cb(node)
×
955
                        },
×
956
                )
957
        }, reset)
958
}
959

960
// ForEachNodeDirectedChannel iterates through all channels of a given node,
961
// executing the passed callback on the directed edge representing the channel
962
// and its incoming policy. If the callback returns an error, then the iteration
963
// is halted with the error propagated back up to the caller.
964
//
965
// Unknown policies are passed into the callback as nil values.
966
//
967
// NOTE: this is part of the graphdb.NodeTraverser interface.
968
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
969
        cb func(channel *DirectedChannel) error, reset func()) error {
×
970

×
971
        var ctx = context.TODO()
×
972

×
973
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
974
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
975
        }, reset)
×
976
}
977

978
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
979
// graph, executing the passed callback with each node encountered. If the
980
// callback returns an error, then the transaction is aborted and the iteration
981
// stops early.
982
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
983
        cb func(route.Vertex, *lnwire.FeatureVector) error,
984
        reset func()) error {
×
985

×
986
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
987
                return forEachNodeCacheable(
×
988
                        ctx, s.cfg.QueryCfg, db,
×
989
                        func(_ int64, nodePub route.Vertex,
×
990
                                features *lnwire.FeatureVector) error {
×
991

×
992
                                return cb(nodePub, features)
×
993
                        },
×
994
                )
995
        }, reset)
996
        if err != nil {
×
997
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
998
        }
×
999

1000
        return nil
×
1001
}
1002

1003
// ForEachNodeChannel iterates through all channels of the given node,
1004
// executing the passed callback with an edge info structure and the policies
1005
// of each end of the channel. The first edge policy is the outgoing edge *to*
1006
// the connecting node, while the second is the incoming edge *from* the
1007
// connecting node. If the callback returns an error, then the iteration is
1008
// halted with the error propagated back up to the caller.
1009
//
1010
// Unknown policies are passed into the callback as nil values.
1011
//
1012
// NOTE: part of the Store interface.
1013
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
1014
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1015
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1016

×
1017
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1018
                dbNode, err := db.GetNodeByPubKey(
×
1019
                        ctx, sqlc.GetNodeByPubKeyParams{
×
1020
                                Version: int16(lnwire.GossipVersion1),
×
1021
                                PubKey:  nodePub[:],
×
1022
                        },
×
1023
                )
×
1024
                if errors.Is(err, sql.ErrNoRows) {
×
1025
                        return nil
×
1026
                } else if err != nil {
×
1027
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1028
                }
×
1029

1030
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
1031
        }, reset)
1032
}
1033

1034
// extractMaxUpdateTime returns the maximum of the two policy update times.
1035
// This is used for pagination cursor tracking.
1036
func extractMaxUpdateTime(
1037
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1038

×
1039
        switch {
×
1040
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1041
                return max(row.Policy1LastUpdate.Int64,
×
1042
                        row.Policy2LastUpdate.Int64)
×
1043
        case row.Policy1LastUpdate.Valid:
×
1044
                return row.Policy1LastUpdate.Int64
×
1045
        case row.Policy2LastUpdate.Valid:
×
1046
                return row.Policy2LastUpdate.Int64
×
1047
        default:
×
1048
                return 0
×
1049
        }
1050
}
1051

1052
// buildChannelFromRow constructs a ChannelEdge from a database row.
1053
// This includes building the nodes, channel info, and policies.
1054
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1055
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1056

×
1057
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1058
        if err != nil {
×
1059
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1060
                        err)
×
1061
        }
×
1062

1063
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1064
        if err != nil {
×
1065
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1066
                        err)
×
1067
        }
×
1068

1069
        channel, err := getAndBuildEdgeInfo(
×
1070
                ctx, s.cfg, db,
×
1071
                row.GraphChannel, node1.PubKeyBytes,
×
1072
                node2.PubKeyBytes,
×
1073
        )
×
1074
        if err != nil {
×
1075
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1076
                        "channel info: %w", err)
×
1077
        }
×
1078

1079
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1080
        if err != nil {
×
1081
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1082
                        "channel policies: %w", err)
×
1083
        }
×
1084

1085
        p1, p2, err := getAndBuildChanPolicies(
×
1086
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1087
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1088
        )
×
1089
        if err != nil {
×
1090
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1091
                        "channel policies: %w", err)
×
1092
        }
×
1093

1094
        return ChannelEdge{
×
1095
                Info:    channel,
×
1096
                Policy1: p1,
×
1097
                Policy2: p2,
×
1098
                Node1:   node1,
×
1099
                Node2:   node2,
×
1100
        }, nil
×
1101
}
1102

1103
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1104
// This method acquires the cache lock only once for the entire batch.
1105
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1106
        if len(edgesToCache) == 0 {
×
1107
                return
×
1108
        }
×
1109

1110
        s.cacheMu.Lock()
×
1111
        defer s.cacheMu.Unlock()
×
1112

×
1113
        for chanID, edge := range edgesToCache {
×
1114
                s.chanCache.insert(chanID, edge)
×
1115
        }
×
1116
}
1117

1118
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1119
// one edge that has an update timestamp within the specified horizon.
1120
//
1121
// Iterator Lifecycle:
1122
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1123
// 2. Query batch of channels with policies in time range
1124
// 3. For each channel: check if seen, check cache, or build from DB
1125
// 4. Yield channels to caller
1126
// 5. Update cache after successful batch
1127
// 6. Repeat with updated pagination cursor until no more results
1128
//
1129
// NOTE: This is part of the Store interface.
1130
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1131
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1132

×
1133
        // Apply options.
×
1134
        cfg := defaultIteratorConfig()
×
1135
        for _, opt := range opts {
×
1136
                opt(cfg)
×
1137
        }
×
1138

1139
        return func(yield func(ChannelEdge, error) bool) {
×
1140
                var (
×
1141
                        ctx            = context.TODO()
×
1142
                        edgesSeen      = make(map[uint64]struct{})
×
1143
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1144
                        hits           int
×
1145
                        total          int
×
1146
                        lastUpdateTime sql.NullInt64
×
1147
                        lastID         sql.NullInt64
×
1148
                        hasMore        = true
×
1149
                )
×
1150

×
1151
                // Each iteration, we'll read a batch amount of channel updates
×
1152
                // (consulting the cache along the way), yield them, then loop
×
1153
                // back to decide if we have any more updates to read out.
×
1154
                for hasMore {
×
1155
                        var batch []ChannelEdge
×
1156

×
1157
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1158
                                func(db SQLQueries) error {
×
1159
                                        //nolint:ll
×
1160
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1161
                                                Version: int16(lnwire.GossipVersion1),
×
1162
                                                StartTime: sqldb.SQLInt64(
×
1163
                                                        startTime.Unix(),
×
1164
                                                ),
×
1165
                                                EndTime: sqldb.SQLInt64(
×
1166
                                                        endTime.Unix(),
×
1167
                                                ),
×
1168
                                                LastUpdateTime: lastUpdateTime,
×
1169
                                                LastID:         lastID,
×
1170
                                                MaxResults: sql.NullInt32{
×
1171
                                                        Int32: int32(
×
1172
                                                                cfg.chanUpdateIterBatchSize,
×
1173
                                                        ),
×
1174
                                                        Valid: true,
×
1175
                                                },
×
1176
                                        }
×
1177
                                        //nolint:ll
×
1178
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1179
                                                ctx, params,
×
1180
                                        )
×
1181
                                        if err != nil {
×
1182
                                                return err
×
1183
                                        }
×
1184

1185
                                        //nolint:ll
1186
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1187

×
1188
                                        //nolint:ll
×
1189
                                        for _, row := range rows {
×
1190
                                                lastUpdateTime = sql.NullInt64{
×
1191
                                                        Int64: extractMaxUpdateTime(row),
×
1192
                                                        Valid: true,
×
1193
                                                }
×
1194
                                                lastID = sql.NullInt64{
×
1195
                                                        Int64: row.GraphChannel.ID,
×
1196
                                                        Valid: true,
×
1197
                                                }
×
1198

×
1199
                                                // Skip if we've already
×
1200
                                                // processed this channel.
×
1201
                                                chanIDInt := byteOrder.Uint64(
×
1202
                                                        row.GraphChannel.Scid,
×
1203
                                                )
×
1204
                                                _, ok := edgesSeen[chanIDInt]
×
1205
                                                if ok {
×
1206
                                                        continue
×
1207
                                                }
1208

1209
                                                s.cacheMu.RLock()
×
1210
                                                channel, ok := s.chanCache.get(
×
1211
                                                        chanIDInt,
×
1212
                                                )
×
1213
                                                s.cacheMu.RUnlock()
×
1214
                                                if ok {
×
1215
                                                        hits++
×
1216
                                                        total++
×
1217
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1218
                                                        batch = append(batch, channel)
×
1219

×
1220
                                                        continue
×
1221
                                                }
1222

1223
                                                chanEdge, err := s.buildChannelFromRow(
×
1224
                                                        ctx, db, row,
×
1225
                                                )
×
1226
                                                if err != nil {
×
1227
                                                        return err
×
1228
                                                }
×
1229

1230
                                                edgesSeen[chanIDInt] = struct{}{}
×
1231
                                                edgesToCache[chanIDInt] = chanEdge
×
1232

×
1233
                                                batch = append(batch, chanEdge)
×
1234

×
1235
                                                total++
×
1236
                                        }
1237

1238
                                        return nil
×
1239
                                }, func() {
×
1240
                                        batch = nil
×
1241
                                        edgesSeen = make(map[uint64]struct{})
×
1242
                                        edgesToCache = make(
×
1243
                                                map[uint64]ChannelEdge,
×
1244
                                        )
×
1245
                                })
×
1246

1247
                        if err != nil {
×
1248
                                log.Errorf("ChanUpdatesInHorizon "+
×
1249
                                        "batch error: %v", err)
×
1250

×
1251
                                yield(ChannelEdge{}, err)
×
1252

×
1253
                                return
×
1254
                        }
×
1255

1256
                        for _, edge := range batch {
×
1257
                                if !yield(edge, nil) {
×
1258
                                        return
×
1259
                                }
×
1260
                        }
1261

1262
                        // Update cache after successful batch yield, setting
1263
                        // the cache lock only once for the entire batch.
1264
                        s.updateChanCacheBatch(edgesToCache)
×
1265
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1266

×
1267
                        // If the batch didn't yield anything, then we're done.
×
1268
                        if len(batch) == 0 {
×
1269
                                break
×
1270
                        }
1271
                }
1272

1273
                if total > 0 {
×
1274
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1275
                                "%.2f (%d/%d)",
×
1276
                                float64(hits)*100/float64(total), hits, total)
×
1277
                } else {
×
1278
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1279
                                "in horizon (%s, %s)", startTime, endTime)
×
1280
                }
×
1281
        }
1282
}
1283

1284
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1285
// data to the call-back. If withAddrs is true, then the call-back will also be
1286
// provided with the addresses associated with the node. The address retrieval
1287
// result in an additional round-trip to the database, so it should only be used
1288
// if the addresses are actually needed.
1289
//
1290
// NOTE: part of the Store interface.
1291
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1292
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1293
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1294

×
1295
        type nodeCachedBatchData struct {
×
1296
                features      map[int64][]int
×
1297
                addrs         map[int64][]nodeAddress
×
1298
                chanBatchData *batchChannelData
×
1299
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1300
        }
×
1301

×
1302
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1303
                // pageQueryFunc is used to query the next page of nodes.
×
1304
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1305
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1306

×
1307
                        return db.ListNodeIDsAndPubKeys(
×
1308
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1309
                                        Version: int16(lnwire.GossipVersion1),
×
1310
                                        ID:      lastID,
×
1311
                                        Limit:   limit,
×
1312
                                },
×
1313
                        )
×
1314
                }
×
1315

1316
                // batchDataFunc is then used to batch load the data required
1317
                // for each page of nodes.
1318
                batchDataFunc := func(ctx context.Context,
×
1319
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1320

×
1321
                        // Batch load node features.
×
1322
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1323
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1324
                        )
×
1325
                        if err != nil {
×
1326
                                return nil, fmt.Errorf("unable to batch load "+
×
1327
                                        "node features: %w", err)
×
1328
                        }
×
1329

1330
                        // Maybe fetch the node's addresses if requested.
1331
                        var nodeAddrs map[int64][]nodeAddress
×
1332
                        if withAddrs {
×
1333
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1334
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1335
                                )
×
1336
                                if err != nil {
×
1337
                                        return nil, fmt.Errorf("unable to "+
×
1338
                                                "batch load node "+
×
1339
                                                "addresses: %w", err)
×
1340
                                }
×
1341
                        }
1342

1343
                        // Batch load ALL unique channels for ALL nodes in this
1344
                        // page.
1345
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1346
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1347
                                        Version:  int16(lnwire.GossipVersion1),
×
1348
                                        Node1Ids: nodeIDs,
×
1349
                                        Node2Ids: nodeIDs,
×
1350
                                },
×
1351
                        )
×
1352
                        if err != nil {
×
1353
                                return nil, fmt.Errorf("unable to batch "+
×
1354
                                        "fetch channels for nodes: %w", err)
×
1355
                        }
×
1356

1357
                        // Deduplicate channels and collect IDs.
1358
                        var (
×
1359
                                allChannelIDs []int64
×
1360
                                allPolicyIDs  []int64
×
1361
                        )
×
1362
                        uniqueChannels := make(
×
1363
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1364
                        )
×
1365

×
1366
                        for _, channel := range allChannels {
×
1367
                                channelID := channel.GraphChannel.ID
×
1368

×
1369
                                // Only process each unique channel once.
×
1370
                                _, exists := uniqueChannels[channelID]
×
1371
                                if exists {
×
1372
                                        continue
×
1373
                                }
1374

1375
                                uniqueChannels[channelID] = channel
×
1376
                                allChannelIDs = append(allChannelIDs, channelID)
×
1377

×
1378
                                if channel.Policy1ID.Valid {
×
1379
                                        allPolicyIDs = append(
×
1380
                                                allPolicyIDs,
×
1381
                                                channel.Policy1ID.Int64,
×
1382
                                        )
×
1383
                                }
×
1384
                                if channel.Policy2ID.Valid {
×
1385
                                        allPolicyIDs = append(
×
1386
                                                allPolicyIDs,
×
1387
                                                channel.Policy2ID.Int64,
×
1388
                                        )
×
1389
                                }
×
1390
                        }
1391

1392
                        // Batch load channel data for all unique channels.
1393
                        channelBatchData, err := batchLoadChannelData(
×
1394
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1395
                                allPolicyIDs,
×
1396
                        )
×
1397
                        if err != nil {
×
1398
                                return nil, fmt.Errorf("unable to batch "+
×
1399
                                        "load channel data: %w", err)
×
1400
                        }
×
1401

1402
                        // Create map of node ID to channels that involve this
1403
                        // node.
1404
                        nodeIDSet := make(map[int64]bool)
×
1405
                        for _, nodeID := range nodeIDs {
×
1406
                                nodeIDSet[nodeID] = true
×
1407
                        }
×
1408

1409
                        nodeChannelMap := make(
×
1410
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1411
                        )
×
1412
                        for _, channel := range uniqueChannels {
×
1413
                                // Add channel to both nodes if they're in our
×
1414
                                // current page.
×
1415
                                node1 := channel.GraphChannel.NodeID1
×
1416
                                if nodeIDSet[node1] {
×
1417
                                        nodeChannelMap[node1] = append(
×
1418
                                                nodeChannelMap[node1], channel,
×
1419
                                        )
×
1420
                                }
×
1421
                                node2 := channel.GraphChannel.NodeID2
×
1422
                                if nodeIDSet[node2] {
×
1423
                                        nodeChannelMap[node2] = append(
×
1424
                                                nodeChannelMap[node2], channel,
×
1425
                                        )
×
1426
                                }
×
1427
                        }
1428

1429
                        return &nodeCachedBatchData{
×
1430
                                features:      nodeFeatures,
×
1431
                                addrs:         nodeAddrs,
×
1432
                                chanBatchData: channelBatchData,
×
1433
                                chanMap:       nodeChannelMap,
×
1434
                        }, nil
×
1435
                }
1436

1437
                // processItem is used to process each node in the current page.
1438
                processItem := func(ctx context.Context,
×
1439
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1440
                        batchData *nodeCachedBatchData) error {
×
1441

×
1442
                        // Build feature vector for this node.
×
1443
                        fv := lnwire.EmptyFeatureVector()
×
1444
                        features, exists := batchData.features[nodeData.ID]
×
1445
                        if exists {
×
1446
                                for _, bit := range features {
×
1447
                                        fv.Set(lnwire.FeatureBit(bit))
×
1448
                                }
×
1449
                        }
1450

1451
                        var nodePub route.Vertex
×
1452
                        copy(nodePub[:], nodeData.PubKey)
×
1453

×
1454
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1455

×
1456
                        toNodeCallback := func() route.Vertex {
×
1457
                                return nodePub
×
1458
                        }
×
1459

1460
                        // Build cached channels map for this node.
1461
                        channels := make(map[uint64]*DirectedChannel)
×
1462
                        for _, channelRow := range nodeChannels {
×
1463
                                directedChan, err := buildDirectedChannel(
×
1464
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1465
                                        channelRow, batchData.chanBatchData, fv,
×
1466
                                        toNodeCallback,
×
1467
                                )
×
1468
                                if err != nil {
×
1469
                                        return err
×
1470
                                }
×
1471

1472
                                channels[directedChan.ChannelID] = directedChan
×
1473
                        }
1474

1475
                        addrs, err := buildNodeAddresses(
×
1476
                                batchData.addrs[nodeData.ID],
×
1477
                        )
×
1478
                        if err != nil {
×
1479
                                return fmt.Errorf("unable to build node "+
×
1480
                                        "addresses: %w", err)
×
1481
                        }
×
1482

1483
                        return cb(ctx, nodePub, addrs, channels)
×
1484
                }
1485

1486
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1487
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1488
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1489
                                return node.ID
×
1490
                        },
×
1491
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1492
                                error) {
×
1493

×
1494
                                return node.ID, nil
×
1495
                        },
×
1496
                        batchDataFunc, processItem,
1497
                )
1498
        }, reset)
1499
}
1500

1501
// ForEachChannelCacheable iterates through all the channel edges stored
1502
// within the graph and invokes the passed callback for each edge. The
1503
// callback takes two edges as since this is a directed graph, both the
1504
// in/out edges are visited. If the callback returns an error, then the
1505
// transaction is aborted and the iteration stops early.
1506
//
1507
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1508
// pointer for that particular channel edge routing policy will be
1509
// passed into the callback.
1510
//
1511
// NOTE: this method is like ForEachChannel but fetches only the data
1512
// required for the graph cache.
1513
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1514
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1515
        reset func()) error {
×
1516

×
1517
        ctx := context.TODO()
×
1518

×
1519
        handleChannel := func(_ context.Context,
×
1520
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1521

×
1522
                node1, node2, err := buildNodeVertices(
×
1523
                        row.Node1Pubkey, row.Node2Pubkey,
×
1524
                )
×
1525
                if err != nil {
×
1526
                        return err
×
1527
                }
×
1528

1529
                edge := buildCacheableChannelInfo(
×
1530
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1531
                )
×
1532

×
1533
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1534
                if err != nil {
×
1535
                        return err
×
1536
                }
×
1537

1538
                pol1, pol2, err := buildCachedChanPolicies(
×
1539
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1540
                )
×
1541
                if err != nil {
×
1542
                        return err
×
1543
                }
×
1544

1545
                return cb(edge, pol1, pol2)
×
1546
        }
1547

1548
        extractCursor := func(
×
1549
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1550

×
1551
                return row.ID
×
1552
        }
×
1553

1554
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1555
                //nolint:ll
×
1556
                queryFunc := func(ctx context.Context, lastID int64,
×
1557
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1558
                        error) {
×
1559

×
1560
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1561
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1562
                                        Version: int16(lnwire.GossipVersion1),
×
1563
                                        ID:      lastID,
×
1564
                                        Limit:   limit,
×
1565
                                },
×
1566
                        )
×
1567
                }
×
1568

1569
                return sqldb.ExecutePaginatedQuery(
×
1570
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1571
                        extractCursor, handleChannel,
×
1572
                )
×
1573
        }, reset)
1574
}
1575

1576
// ForEachChannel iterates through all the channel edges stored within the
1577
// graph and invokes the passed callback for each edge. The callback takes two
1578
// edges as since this is a directed graph, both the in/out edges are visited.
1579
// If the callback returns an error, then the transaction is aborted and the
1580
// iteration stops early.
1581
//
1582
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1583
// for that particular channel edge routing policy will be passed into the
1584
// callback.
1585
//
1586
// NOTE: part of the Store interface.
1587
func (s *SQLStore) ForEachChannel(ctx context.Context,
1588
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1589
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1590

×
1591
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1592
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1593
        }, reset)
×
1594
}
1595

1596
// FilterChannelRange returns the channel ID's of all known channels which were
1597
// mined in a block height within the passed range. The channel IDs are grouped
1598
// by their common block height. This method can be used to quickly share with a
1599
// peer the set of channels we know of within a particular range to catch them
1600
// up after a period of time offline. If withTimestamps is true then the
1601
// timestamp info of the latest received channel update messages of the channel
1602
// will be included in the response.
1603
//
1604
// NOTE: This is part of the Store interface.
1605
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1606
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1607

×
1608
        var (
×
1609
                ctx       = context.TODO()
×
1610
                startSCID = &lnwire.ShortChannelID{
×
1611
                        BlockHeight: startHeight,
×
1612
                }
×
1613
                endSCID = lnwire.ShortChannelID{
×
1614
                        BlockHeight: endHeight,
×
1615
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1616
                        TxPosition:  math.MaxUint16,
×
1617
                }
×
1618
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1619
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1620
        )
×
1621

×
1622
        // 1) get all channels where channelID is between start and end chan ID.
×
1623
        // 2) skip if not public (ie, no channel_proof)
×
1624
        // 3) collect that channel.
×
1625
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1626
        //    and add those timestamps to the collected channel.
×
1627
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1628
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1629
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1630
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1631
                                StartScid: chanIDStart,
×
1632
                                EndScid:   chanIDEnd,
×
1633
                        },
×
1634
                )
×
1635
                if err != nil {
×
1636
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1637
                                err)
×
1638
                }
×
1639

1640
                for _, dbChan := range dbChans {
×
1641
                        cid := lnwire.NewShortChanIDFromInt(
×
1642
                                byteOrder.Uint64(dbChan.Scid),
×
1643
                        )
×
1644
                        chanInfo := NewChannelUpdateInfo(
×
1645
                                cid, time.Time{}, time.Time{},
×
1646
                        )
×
1647

×
1648
                        if !withTimestamps {
×
1649
                                channelsPerBlock[cid.BlockHeight] = append(
×
1650
                                        channelsPerBlock[cid.BlockHeight],
×
1651
                                        chanInfo,
×
1652
                                )
×
1653

×
1654
                                continue
×
1655
                        }
1656

1657
                        //nolint:ll
1658
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1659
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1660
                                        Version:   int16(lnwire.GossipVersion1),
×
1661
                                        ChannelID: dbChan.ID,
×
1662
                                        NodeID:    dbChan.NodeID1,
×
1663
                                },
×
1664
                        )
×
1665
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1666
                                return fmt.Errorf("unable to fetch node1 "+
×
1667
                                        "policy: %w", err)
×
1668
                        } else if err == nil {
×
1669
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1670
                                        node1Policy.LastUpdate.Int64, 0,
×
1671
                                )
×
1672
                        }
×
1673

1674
                        //nolint:ll
1675
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1676
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1677
                                        Version:   int16(lnwire.GossipVersion1),
×
1678
                                        ChannelID: dbChan.ID,
×
1679
                                        NodeID:    dbChan.NodeID2,
×
1680
                                },
×
1681
                        )
×
1682
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1683
                                return fmt.Errorf("unable to fetch node2 "+
×
1684
                                        "policy: %w", err)
×
1685
                        } else if err == nil {
×
1686
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1687
                                        node2Policy.LastUpdate.Int64, 0,
×
1688
                                )
×
1689
                        }
×
1690

1691
                        channelsPerBlock[cid.BlockHeight] = append(
×
1692
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1693
                        )
×
1694
                }
1695

1696
                return nil
×
1697
        }, func() {
×
1698
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1699
        })
×
1700
        if err != nil {
×
1701
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1702
        }
×
1703

1704
        if len(channelsPerBlock) == 0 {
×
1705
                return nil, nil
×
1706
        }
×
1707

1708
        // Return the channel ranges in ascending block height order.
1709
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1710
        slices.Sort(blocks)
×
1711

×
1712
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1713
                return BlockChannelRange{
×
1714
                        Height:   block,
×
1715
                        Channels: channelsPerBlock[block],
×
1716
                }
×
1717
        }), nil
×
1718
}
1719

1720
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1721
// zombie. This method is used on an ad-hoc basis, when channels need to be
1722
// marked as zombies outside the normal pruning cycle.
1723
//
1724
// NOTE: part of the Store interface.
1725
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1726
        pubKey1, pubKey2 [33]byte) error {
×
1727

×
1728
        ctx := context.TODO()
×
1729

×
1730
        s.cacheMu.Lock()
×
1731
        defer s.cacheMu.Unlock()
×
1732

×
1733
        chanIDB := channelIDToBytes(chanID)
×
1734

×
1735
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1736
                return db.UpsertZombieChannel(
×
1737
                        ctx, sqlc.UpsertZombieChannelParams{
×
1738
                                Version:  int16(lnwire.GossipVersion1),
×
1739
                                Scid:     chanIDB,
×
1740
                                NodeKey1: pubKey1[:],
×
1741
                                NodeKey2: pubKey2[:],
×
1742
                        },
×
1743
                )
×
1744
        }, sqldb.NoOpReset)
×
1745
        if err != nil {
×
1746
                return fmt.Errorf("unable to upsert zombie channel "+
×
1747
                        "(channel_id=%d): %w", chanID, err)
×
1748
        }
×
1749

1750
        s.rejectCache.remove(chanID)
×
1751
        s.chanCache.remove(chanID)
×
1752

×
1753
        return nil
×
1754
}
1755

1756
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1757
//
1758
// NOTE: part of the Store interface.
1759
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1760
        s.cacheMu.Lock()
×
1761
        defer s.cacheMu.Unlock()
×
1762

×
1763
        var (
×
1764
                ctx     = context.TODO()
×
1765
                chanIDB = channelIDToBytes(chanID)
×
1766
        )
×
1767

×
1768
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1769
                res, err := db.DeleteZombieChannel(
×
1770
                        ctx, sqlc.DeleteZombieChannelParams{
×
1771
                                Scid:    chanIDB,
×
1772
                                Version: int16(lnwire.GossipVersion1),
×
1773
                        },
×
1774
                )
×
1775
                if err != nil {
×
1776
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1777
                                err)
×
1778
                }
×
1779

1780
                rows, err := res.RowsAffected()
×
1781
                if err != nil {
×
1782
                        return err
×
1783
                }
×
1784

1785
                if rows == 0 {
×
1786
                        return ErrZombieEdgeNotFound
×
1787
                } else if rows > 1 {
×
1788
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1789
                                "expected 1", rows)
×
1790
                }
×
1791

1792
                return nil
×
1793
        }, sqldb.NoOpReset)
1794
        if err != nil {
×
1795
                return fmt.Errorf("unable to mark edge live "+
×
1796
                        "(channel_id=%d): %w", chanID, err)
×
1797
        }
×
1798

1799
        s.rejectCache.remove(chanID)
×
1800
        s.chanCache.remove(chanID)
×
1801

×
1802
        return err
×
1803
}
1804

1805
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1806
// zombie, then the two node public keys corresponding to this edge are also
1807
// returned.
1808
//
1809
// NOTE: part of the Store interface.
1810
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1811
        error) {
×
1812

×
1813
        var (
×
1814
                ctx              = context.TODO()
×
1815
                isZombie         bool
×
1816
                pubKey1, pubKey2 route.Vertex
×
1817
                chanIDB          = channelIDToBytes(chanID)
×
1818
        )
×
1819

×
1820
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1821
                zombie, err := db.GetZombieChannel(
×
1822
                        ctx, sqlc.GetZombieChannelParams{
×
1823
                                Scid:    chanIDB,
×
1824
                                Version: int16(lnwire.GossipVersion1),
×
1825
                        },
×
1826
                )
×
1827
                if errors.Is(err, sql.ErrNoRows) {
×
1828
                        return nil
×
1829
                }
×
1830
                if err != nil {
×
1831
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1832
                                err)
×
1833
                }
×
1834

1835
                copy(pubKey1[:], zombie.NodeKey1)
×
1836
                copy(pubKey2[:], zombie.NodeKey2)
×
1837
                isZombie = true
×
1838

×
1839
                return nil
×
1840
        }, sqldb.NoOpReset)
1841
        if err != nil {
×
1842
                return false, route.Vertex{}, route.Vertex{},
×
1843
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1844
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1845
        }
×
1846

1847
        return isZombie, pubKey1, pubKey2, nil
×
1848
}
1849

1850
// NumZombies returns the current number of zombie channels in the graph.
1851
//
1852
// NOTE: part of the Store interface.
1853
func (s *SQLStore) NumZombies() (uint64, error) {
×
1854
        var (
×
1855
                ctx        = context.TODO()
×
1856
                numZombies uint64
×
1857
        )
×
1858
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1859
                count, err := db.CountZombieChannels(
×
1860
                        ctx, int16(lnwire.GossipVersion1),
×
1861
                )
×
1862
                if err != nil {
×
1863
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1864
                                err)
×
1865
                }
×
1866

1867
                numZombies = uint64(count)
×
1868

×
1869
                return nil
×
1870
        }, sqldb.NoOpReset)
1871
        if err != nil {
×
1872
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1873
        }
×
1874

1875
        return numZombies, nil
×
1876
}
1877

1878
// DeleteChannelEdges removes edges with the given channel IDs from the
1879
// database and marks them as zombies. This ensures that we're unable to re-add
1880
// it to our database once again. If an edge does not exist within the
1881
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1882
// true, then when we mark these edges as zombies, we'll set up the keys such
1883
// that we require the node that failed to send the fresh update to be the one
1884
// that resurrects the channel from its zombie state. The markZombie bool
1885
// denotes whether to mark the channel as a zombie.
1886
//
1887
// NOTE: part of the Store interface.
1888
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1889
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1890

×
1891
        s.cacheMu.Lock()
×
1892
        defer s.cacheMu.Unlock()
×
1893

×
1894
        // Keep track of which channels we end up finding so that we can
×
1895
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1896
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1897
        for _, chanID := range chanIDs {
×
1898
                chanLookup[chanID] = struct{}{}
×
1899
        }
×
1900

1901
        var (
×
1902
                ctx   = context.TODO()
×
1903
                edges []*models.ChannelEdgeInfo
×
1904
        )
×
1905
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1906
                // First, collect all channel rows.
×
1907
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1908
                chanCallBack := func(ctx context.Context,
×
1909
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1910

×
1911
                        // Deleting the entry from the map indicates that we
×
1912
                        // have found the channel.
×
1913
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1914
                        delete(chanLookup, scid)
×
1915

×
1916
                        channelRows = append(channelRows, row)
×
1917

×
1918
                        return nil
×
1919
                }
×
1920

1921
                err := s.forEachChanWithPoliciesInSCIDList(
×
1922
                        ctx, db, chanCallBack, chanIDs,
×
1923
                )
×
1924
                if err != nil {
×
1925
                        return err
×
1926
                }
×
1927

1928
                if len(chanLookup) > 0 {
×
1929
                        return ErrEdgeNotFound
×
1930
                }
×
1931

1932
                if len(channelRows) == 0 {
×
1933
                        return nil
×
1934
                }
×
1935

1936
                // Batch build all channel edges.
1937
                var chanIDsToDelete []int64
×
1938
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1939
                        ctx, s.cfg, db, channelRows,
×
1940
                )
×
1941
                if err != nil {
×
1942
                        return err
×
1943
                }
×
1944

1945
                if markZombie {
×
1946
                        for i, row := range channelRows {
×
1947
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1948

×
1949
                                err := handleZombieMarking(
×
1950
                                        ctx, db, row, edges[i],
×
1951
                                        strictZombiePruning, scid,
×
1952
                                )
×
1953
                                if err != nil {
×
1954
                                        return fmt.Errorf("unable to mark "+
×
1955
                                                "channel as zombie: %w", err)
×
1956
                                }
×
1957
                        }
1958
                }
1959

1960
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1961
        }, func() {
×
1962
                edges = nil
×
1963

×
1964
                // Re-fill the lookup map.
×
1965
                for _, chanID := range chanIDs {
×
1966
                        chanLookup[chanID] = struct{}{}
×
1967
                }
×
1968
        })
1969
        if err != nil {
×
1970
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1971
                        err)
×
1972
        }
×
1973

1974
        for _, chanID := range chanIDs {
×
1975
                s.rejectCache.remove(chanID)
×
1976
                s.chanCache.remove(chanID)
×
1977
        }
×
1978

1979
        return edges, nil
×
1980
}
1981

1982
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1983
// channel identified by the channel ID. If the channel can't be found, then
1984
// ErrEdgeNotFound is returned. A struct which houses the general information
1985
// for the channel itself is returned as well as two structs that contain the
1986
// routing policies for the channel in either direction.
1987
//
1988
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1989
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1990
// the ChannelEdgeInfo will only include the public keys of each node.
1991
//
1992
// NOTE: part of the Store interface.
1993
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1994
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1995
        *models.ChannelEdgePolicy, error) {
×
1996

×
1997
        var (
×
1998
                ctx              = context.TODO()
×
1999
                edge             *models.ChannelEdgeInfo
×
2000
                policy1, policy2 *models.ChannelEdgePolicy
×
2001
                chanIDB          = channelIDToBytes(chanID)
×
2002
        )
×
2003
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2004
                row, err := db.GetChannelBySCIDWithPolicies(
×
2005
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
2006
                                Scid:    chanIDB,
×
2007
                                Version: int16(lnwire.GossipVersion1),
×
2008
                        },
×
2009
                )
×
2010
                if errors.Is(err, sql.ErrNoRows) {
×
2011
                        // First check if this edge is perhaps in the zombie
×
2012
                        // index.
×
2013
                        zombie, err := db.GetZombieChannel(
×
2014
                                ctx, sqlc.GetZombieChannelParams{
×
2015
                                        Scid:    chanIDB,
×
2016
                                        Version: int16(lnwire.GossipVersion1),
×
2017
                                },
×
2018
                        )
×
2019
                        if errors.Is(err, sql.ErrNoRows) {
×
2020
                                return ErrEdgeNotFound
×
2021
                        } else if err != nil {
×
2022
                                return fmt.Errorf("unable to check if "+
×
2023
                                        "channel is zombie: %w", err)
×
2024
                        }
×
2025

2026
                        // At this point, we know the channel is a zombie, so
2027
                        // we'll return an error indicating this, and we will
2028
                        // populate the edge info with the public keys of each
2029
                        // party as this is the only information we have about
2030
                        // it.
2031
                        edge = &models.ChannelEdgeInfo{}
×
2032
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
2033
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
2034

×
2035
                        return ErrZombieEdge
×
2036
                } else if err != nil {
×
2037
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2038
                }
×
2039

2040
                node1, node2, err := buildNodeVertices(
×
2041
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2042
                )
×
2043
                if err != nil {
×
2044
                        return err
×
2045
                }
×
2046

2047
                edge, err = getAndBuildEdgeInfo(
×
2048
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2049
                )
×
2050
                if err != nil {
×
2051
                        return fmt.Errorf("unable to build channel info: %w",
×
2052
                                err)
×
2053
                }
×
2054

2055
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2056
                if err != nil {
×
2057
                        return fmt.Errorf("unable to extract channel "+
×
2058
                                "policies: %w", err)
×
2059
                }
×
2060

2061
                policy1, policy2, err = getAndBuildChanPolicies(
×
2062
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2063
                        node1, node2,
×
2064
                )
×
2065
                if err != nil {
×
2066
                        return fmt.Errorf("unable to build channel "+
×
2067
                                "policies: %w", err)
×
2068
                }
×
2069

2070
                return nil
×
2071
        }, sqldb.NoOpReset)
2072
        if err != nil {
×
2073
                // If we are returning the ErrZombieEdge, then we also need to
×
2074
                // return the edge info as the method comment indicates that
×
2075
                // this will be populated when the edge is a zombie.
×
2076
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2077
                        err)
×
2078
        }
×
2079

2080
        return edge, policy1, policy2, nil
×
2081
}
2082

2083
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2084
// the channel identified by the funding outpoint. If the channel can't be
2085
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2086
// information for the channel itself is returned as well as two structs that
2087
// contain the routing policies for the channel in either direction.
2088
//
2089
// NOTE: part of the Store interface.
2090
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2091
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2092
        *models.ChannelEdgePolicy, error) {
×
2093

×
2094
        var (
×
2095
                ctx              = context.TODO()
×
2096
                edge             *models.ChannelEdgeInfo
×
2097
                policy1, policy2 *models.ChannelEdgePolicy
×
2098
        )
×
2099
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2100
                row, err := db.GetChannelByOutpointWithPolicies(
×
2101
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2102
                                Outpoint: op.String(),
×
2103
                                Version:  int16(lnwire.GossipVersion1),
×
2104
                        },
×
2105
                )
×
2106
                if errors.Is(err, sql.ErrNoRows) {
×
2107
                        return ErrEdgeNotFound
×
2108
                } else if err != nil {
×
2109
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2110
                }
×
2111

2112
                node1, node2, err := buildNodeVertices(
×
2113
                        row.Node1Pubkey, row.Node2Pubkey,
×
2114
                )
×
2115
                if err != nil {
×
2116
                        return err
×
2117
                }
×
2118

2119
                edge, err = getAndBuildEdgeInfo(
×
2120
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2121
                )
×
2122
                if err != nil {
×
2123
                        return fmt.Errorf("unable to build channel info: %w",
×
2124
                                err)
×
2125
                }
×
2126

2127
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2128
                if err != nil {
×
2129
                        return fmt.Errorf("unable to extract channel "+
×
2130
                                "policies: %w", err)
×
2131
                }
×
2132

2133
                policy1, policy2, err = getAndBuildChanPolicies(
×
2134
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2135
                        node1, node2,
×
2136
                )
×
2137
                if err != nil {
×
2138
                        return fmt.Errorf("unable to build channel "+
×
2139
                                "policies: %w", err)
×
2140
                }
×
2141

2142
                return nil
×
2143
        }, sqldb.NoOpReset)
2144
        if err != nil {
×
2145
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2146
                        err)
×
2147
        }
×
2148

2149
        return edge, policy1, policy2, nil
×
2150
}
2151

2152
// HasChannelEdge returns true if the database knows of a channel edge with the
2153
// passed channel ID, and false otherwise. If an edge with that ID is found
2154
// within the graph, then two time stamps representing the last time the edge
2155
// was updated for both directed edges are returned along with the boolean. If
2156
// it is not found, then the zombie index is checked and its result is returned
2157
// as the second boolean.
2158
//
2159
// NOTE: part of the Store interface.
2160
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2161
        bool, error) {
×
2162

×
2163
        ctx := context.TODO()
×
2164

×
2165
        var (
×
2166
                exists          bool
×
2167
                isZombie        bool
×
2168
                node1LastUpdate time.Time
×
2169
                node2LastUpdate time.Time
×
2170
        )
×
2171

×
2172
        // We'll query the cache with the shared lock held to allow multiple
×
2173
        // readers to access values in the cache concurrently if they exist.
×
2174
        s.cacheMu.RLock()
×
2175
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2176
                s.cacheMu.RUnlock()
×
2177
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2178
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2179
                exists, isZombie = entry.flags.unpack()
×
2180

×
2181
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2182
        }
×
2183
        s.cacheMu.RUnlock()
×
2184

×
2185
        s.cacheMu.Lock()
×
2186
        defer s.cacheMu.Unlock()
×
2187

×
2188
        // The item was not found with the shared lock, so we'll acquire the
×
2189
        // exclusive lock and check the cache again in case another method added
×
2190
        // the entry to the cache while no lock was held.
×
2191
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2192
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2193
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2194
                exists, isZombie = entry.flags.unpack()
×
2195

×
2196
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2197
        }
×
2198

2199
        chanIDB := channelIDToBytes(chanID)
×
2200
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2201
                channel, err := db.GetChannelBySCID(
×
2202
                        ctx, sqlc.GetChannelBySCIDParams{
×
2203
                                Scid:    chanIDB,
×
2204
                                Version: int16(lnwire.GossipVersion1),
×
2205
                        },
×
2206
                )
×
2207
                if errors.Is(err, sql.ErrNoRows) {
×
2208
                        // Check if it is a zombie channel.
×
2209
                        isZombie, err = db.IsZombieChannel(
×
2210
                                ctx, sqlc.IsZombieChannelParams{
×
2211
                                        Scid:    chanIDB,
×
2212
                                        Version: int16(lnwire.GossipVersion1),
×
2213
                                },
×
2214
                        )
×
2215
                        if err != nil {
×
2216
                                return fmt.Errorf("could not check if channel "+
×
2217
                                        "is zombie: %w", err)
×
2218
                        }
×
2219

2220
                        return nil
×
2221
                } else if err != nil {
×
2222
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2223
                }
×
2224

2225
                exists = true
×
2226

×
2227
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2228
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2229
                                Version:   int16(lnwire.GossipVersion1),
×
2230
                                ChannelID: channel.ID,
×
2231
                                NodeID:    channel.NodeID1,
×
2232
                        },
×
2233
                )
×
2234
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2235
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2236
                                err)
×
2237
                } else if err == nil {
×
2238
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2239
                }
×
2240

2241
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2242
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2243
                                Version:   int16(lnwire.GossipVersion1),
×
2244
                                ChannelID: channel.ID,
×
2245
                                NodeID:    channel.NodeID2,
×
2246
                        },
×
2247
                )
×
2248
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2249
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2250
                                err)
×
2251
                } else if err == nil {
×
2252
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2253
                }
×
2254

2255
                return nil
×
2256
        }, sqldb.NoOpReset)
2257
        if err != nil {
×
2258
                return time.Time{}, time.Time{}, false, false,
×
2259
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2260
        }
×
2261

2262
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2263
                upd1Time: node1LastUpdate.Unix(),
×
2264
                upd2Time: node2LastUpdate.Unix(),
×
2265
                flags:    packRejectFlags(exists, isZombie),
×
2266
        })
×
2267

×
2268
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2269
}
2270

2271
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2272
// passed channel point (outpoint). If the passed channel doesn't exist within
2273
// the database, then ErrEdgeNotFound is returned.
2274
//
2275
// NOTE: part of the Store interface.
2276
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2277
        var (
×
2278
                ctx       = context.TODO()
×
2279
                channelID uint64
×
2280
        )
×
2281
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2282
                chanID, err := db.GetSCIDByOutpoint(
×
2283
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2284
                                Outpoint: chanPoint.String(),
×
2285
                                Version:  int16(lnwire.GossipVersion1),
×
2286
                        },
×
2287
                )
×
2288
                if errors.Is(err, sql.ErrNoRows) {
×
2289
                        return ErrEdgeNotFound
×
2290
                } else if err != nil {
×
2291
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2292
                                err)
×
2293
                }
×
2294

2295
                channelID = byteOrder.Uint64(chanID)
×
2296

×
2297
                return nil
×
2298
        }, sqldb.NoOpReset)
2299
        if err != nil {
×
2300
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2301
        }
×
2302

2303
        return channelID, nil
×
2304
}
2305

2306
// IsPublicNode is a helper method that determines whether the node with the
2307
// given public key is seen as a public node in the graph from the graph's
2308
// source node's point of view.
2309
//
2310
// NOTE: part of the Store interface.
2311
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2312
        ctx := context.TODO()
×
2313

×
2314
        var isPublic bool
×
2315
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2316
                var err error
×
2317
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2318

×
2319
                return err
×
2320
        }, sqldb.NoOpReset)
×
2321
        if err != nil {
×
2322
                return false, fmt.Errorf("unable to check if node is "+
×
2323
                        "public: %w", err)
×
2324
        }
×
2325

2326
        return isPublic, nil
×
2327
}
2328

2329
// FetchChanInfos returns the set of channel edges that correspond to the passed
2330
// channel ID's. If an edge is the query is unknown to the database, it will
2331
// skipped and the result will contain only those edges that exist at the time
2332
// of the query. This can be used to respond to peer queries that are seeking to
2333
// fill in gaps in their view of the channel graph.
2334
//
2335
// NOTE: part of the Store interface.
2336
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2337
        var (
×
2338
                ctx   = context.TODO()
×
2339
                edges = make(map[uint64]ChannelEdge)
×
2340
        )
×
2341
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2342
                // First, collect all channel rows.
×
2343
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2344
                chanCallBack := func(ctx context.Context,
×
2345
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2346

×
2347
                        channelRows = append(channelRows, row)
×
2348
                        return nil
×
2349
                }
×
2350

2351
                err := s.forEachChanWithPoliciesInSCIDList(
×
2352
                        ctx, db, chanCallBack, chanIDs,
×
2353
                )
×
2354
                if err != nil {
×
2355
                        return err
×
2356
                }
×
2357

2358
                if len(channelRows) == 0 {
×
2359
                        return nil
×
2360
                }
×
2361

2362
                // Batch build all channel edges.
2363
                chans, err := batchBuildChannelEdges(
×
2364
                        ctx, s.cfg, db, channelRows,
×
2365
                )
×
2366
                if err != nil {
×
2367
                        return fmt.Errorf("unable to build channel edges: %w",
×
2368
                                err)
×
2369
                }
×
2370

2371
                for _, c := range chans {
×
2372
                        edges[c.Info.ChannelID] = c
×
2373
                }
×
2374

2375
                return err
×
2376
        }, func() {
×
2377
                clear(edges)
×
2378
        })
×
2379
        if err != nil {
×
2380
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2381
        }
×
2382

2383
        res := make([]ChannelEdge, 0, len(edges))
×
2384
        for _, chanID := range chanIDs {
×
2385
                edge, ok := edges[chanID]
×
2386
                if !ok {
×
2387
                        continue
×
2388
                }
2389

2390
                res = append(res, edge)
×
2391
        }
2392

2393
        return res, nil
×
2394
}
2395

2396
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2397
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2398
// channels in a paginated manner.
2399
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2400
        db SQLQueries, cb func(ctx context.Context,
2401
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2402
        chanIDs []uint64) error {
×
2403

×
2404
        queryWrapper := func(ctx context.Context,
×
2405
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2406
                error) {
×
2407

×
2408
                return db.GetChannelsBySCIDWithPolicies(
×
2409
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2410
                                Version: int16(lnwire.GossipVersion1),
×
2411
                                Scids:   scids,
×
2412
                        },
×
2413
                )
×
2414
        }
×
2415

2416
        return sqldb.ExecuteBatchQuery(
×
2417
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2418
                cb,
×
2419
        )
×
2420
}
2421

2422
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2423
// ID's that we don't know and are not known zombies of the passed set. In other
2424
// words, we perform a set difference of our set of chan ID's and the ones
2425
// passed in. This method can be used by callers to determine the set of
2426
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2427
// known zombies is also returned.
2428
//
2429
// NOTE: part of the Store interface.
2430
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2431
        []ChannelUpdateInfo, error) {
×
2432

×
2433
        var (
×
2434
                ctx          = context.TODO()
×
2435
                newChanIDs   []uint64
×
2436
                knownZombies []ChannelUpdateInfo
×
2437
                infoLookup   = make(
×
2438
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2439
                )
×
2440
        )
×
2441

×
2442
        // We first build a lookup map of the channel ID's to the
×
2443
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2444
        // already know about.
×
2445
        for _, chanInfo := range chansInfo {
×
2446
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2447
        }
×
2448

2449
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2450
                // The call-back function deletes known channels from
×
2451
                // infoLookup, so that we can later check which channels are
×
2452
                // zombies by only looking at the remaining channels in the set.
×
2453
                cb := func(ctx context.Context,
×
2454
                        channel sqlc.GraphChannel) error {
×
2455

×
2456
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2457

×
2458
                        return nil
×
2459
                }
×
2460

2461
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2462
                if err != nil {
×
2463
                        return fmt.Errorf("unable to iterate through "+
×
2464
                                "channels: %w", err)
×
2465
                }
×
2466

2467
                // We want to ensure that we deal with the channels in the
2468
                // same order that they were passed in, so we iterate over the
2469
                // original chansInfo slice and then check if that channel is
2470
                // still in the infoLookup map.
2471
                for _, chanInfo := range chansInfo {
×
2472
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2473
                        if _, ok := infoLookup[channelID]; !ok {
×
2474
                                continue
×
2475
                        }
2476

2477
                        isZombie, err := db.IsZombieChannel(
×
2478
                                ctx, sqlc.IsZombieChannelParams{
×
2479
                                        Scid:    channelIDToBytes(channelID),
×
2480
                                        Version: int16(lnwire.GossipVersion1),
×
2481
                                },
×
2482
                        )
×
2483
                        if err != nil {
×
2484
                                return fmt.Errorf("unable to fetch zombie "+
×
2485
                                        "channel: %w", err)
×
2486
                        }
×
2487

2488
                        if isZombie {
×
2489
                                knownZombies = append(knownZombies, chanInfo)
×
2490

×
2491
                                continue
×
2492
                        }
2493

2494
                        newChanIDs = append(newChanIDs, channelID)
×
2495
                }
2496

2497
                return nil
×
2498
        }, func() {
×
2499
                newChanIDs = nil
×
2500
                knownZombies = nil
×
2501
                // Rebuild the infoLookup map in case of a rollback.
×
2502
                for _, chanInfo := range chansInfo {
×
2503
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2504
                        infoLookup[scid] = chanInfo
×
2505
                }
×
2506
        })
2507
        if err != nil {
×
2508
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2509
        }
×
2510

2511
        return newChanIDs, knownZombies, nil
×
2512
}
2513

2514
// forEachChanInSCIDList is a helper method that executes a paged query
2515
// against the database to fetch all channels that match the passed
2516
// ChannelUpdateInfo slice. The callback function is called for each channel
2517
// that is found.
2518
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2519
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2520
        chansInfo []ChannelUpdateInfo) error {
×
2521

×
2522
        queryWrapper := func(ctx context.Context,
×
2523
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2524

×
2525
                return db.GetChannelsBySCIDs(
×
2526
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2527
                                Version: int16(lnwire.GossipVersion1),
×
2528
                                Scids:   scids,
×
2529
                        },
×
2530
                )
×
2531
        }
×
2532

2533
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2534
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2535

×
2536
                return channelIDToBytes(channelID)
×
2537
        }
×
2538

2539
        return sqldb.ExecuteBatchQuery(
×
2540
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2541
                cb,
×
2542
        )
×
2543
}
2544

2545
// PruneGraphNodes is a garbage collection method which attempts to prune out
2546
// any nodes from the channel graph that are currently unconnected. This ensure
2547
// that we only maintain a graph of reachable nodes. In the event that a pruned
2548
// node gains more channels, it will be re-added back to the graph.
2549
//
2550
// NOTE: this prunes nodes across protocol versions. It will never prune the
2551
// source nodes.
2552
//
2553
// NOTE: part of the Store interface.
2554
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2555
        var ctx = context.TODO()
×
2556

×
2557
        var prunedNodes []route.Vertex
×
2558
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2559
                var err error
×
2560
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2561

×
2562
                return err
×
2563
        }, func() {
×
2564
                prunedNodes = nil
×
2565
        })
×
2566
        if err != nil {
×
2567
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2568
        }
×
2569

2570
        return prunedNodes, nil
×
2571
}
2572

2573
// PruneGraph prunes newly closed channels from the channel graph in response
2574
// to a new block being solved on the network. Any transactions which spend the
2575
// funding output of any known channels within he graph will be deleted.
2576
// Additionally, the "prune tip", or the last block which has been used to
2577
// prune the graph is stored so callers can ensure the graph is fully in sync
2578
// with the current UTXO state. A slice of channels that have been closed by
2579
// the target block along with any pruned nodes are returned if the function
2580
// succeeds without error.
2581
//
2582
// NOTE: part of the Store interface.
2583
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2584
        blockHash *chainhash.Hash, blockHeight uint32) (
2585
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2586

×
2587
        ctx := context.TODO()
×
2588

×
2589
        s.cacheMu.Lock()
×
2590
        defer s.cacheMu.Unlock()
×
2591

×
2592
        var (
×
2593
                closedChans []*models.ChannelEdgeInfo
×
2594
                prunedNodes []route.Vertex
×
2595
        )
×
2596
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2597
                // First, collect all channel rows that need to be pruned.
×
2598
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2599
                channelCallback := func(ctx context.Context,
×
2600
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2601

×
2602
                        channelRows = append(channelRows, row)
×
2603

×
2604
                        return nil
×
2605
                }
×
2606

2607
                err := s.forEachChanInOutpoints(
×
2608
                        ctx, db, spentOutputs, channelCallback,
×
2609
                )
×
2610
                if err != nil {
×
2611
                        return fmt.Errorf("unable to fetch channels by "+
×
2612
                                "outpoints: %w", err)
×
2613
                }
×
2614

2615
                if len(channelRows) == 0 {
×
2616
                        // There are no channels to prune. So we can exit early
×
2617
                        // after updating the prune log.
×
2618
                        err = db.UpsertPruneLogEntry(
×
2619
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2620
                                        BlockHash:   blockHash[:],
×
2621
                                        BlockHeight: int64(blockHeight),
×
2622
                                },
×
2623
                        )
×
2624
                        if err != nil {
×
2625
                                return fmt.Errorf("unable to insert prune log "+
×
2626
                                        "entry: %w", err)
×
2627
                        }
×
2628

2629
                        return nil
×
2630
                }
2631

2632
                // Batch build all channel edges for pruning.
2633
                var chansToDelete []int64
×
2634
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2635
                        ctx, s.cfg, db, channelRows,
×
2636
                )
×
2637
                if err != nil {
×
2638
                        return err
×
2639
                }
×
2640

2641
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2642
                if err != nil {
×
2643
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2644
                }
×
2645

2646
                err = db.UpsertPruneLogEntry(
×
2647
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2648
                                BlockHash:   blockHash[:],
×
2649
                                BlockHeight: int64(blockHeight),
×
2650
                        },
×
2651
                )
×
2652
                if err != nil {
×
2653
                        return fmt.Errorf("unable to insert prune log "+
×
2654
                                "entry: %w", err)
×
2655
                }
×
2656

2657
                // Now that we've pruned some channels, we'll also prune any
2658
                // nodes that no longer have any channels.
2659
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2660
                if err != nil {
×
2661
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2662
                                err)
×
2663
                }
×
2664

2665
                return nil
×
2666
        }, func() {
×
2667
                prunedNodes = nil
×
2668
                closedChans = nil
×
2669
        })
×
2670
        if err != nil {
×
2671
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2672
        }
×
2673

2674
        for _, channel := range closedChans {
×
2675
                s.rejectCache.remove(channel.ChannelID)
×
2676
                s.chanCache.remove(channel.ChannelID)
×
2677
        }
×
2678

2679
        return closedChans, prunedNodes, nil
×
2680
}
2681

2682
// forEachChanInOutpoints is a helper function that executes a paginated
2683
// query to fetch channels by their outpoints and applies the given call-back
2684
// to each.
2685
//
2686
// NOTE: this fetches channels for all protocol versions.
2687
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2688
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2689
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2690

×
2691
        // Create a wrapper that uses the transaction's db instance to execute
×
2692
        // the query.
×
2693
        queryWrapper := func(ctx context.Context,
×
2694
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2695
                error) {
×
2696

×
2697
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2698
        }
×
2699

2700
        // Define the conversion function from Outpoint to string.
2701
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2702
                return outpoint.String()
×
2703
        }
×
2704

2705
        return sqldb.ExecuteBatchQuery(
×
2706
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2707
                queryWrapper, cb,
×
2708
        )
×
2709
}
2710

2711
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2712
        dbIDs []int64) error {
×
2713

×
2714
        // Create a wrapper that uses the transaction's db instance to execute
×
2715
        // the query.
×
2716
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2717
                return nil, db.DeleteChannels(ctx, ids)
×
2718
        }
×
2719

2720
        idConverter := func(id int64) int64 {
×
2721
                return id
×
2722
        }
×
2723

2724
        return sqldb.ExecuteBatchQuery(
×
2725
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2726
                queryWrapper, func(ctx context.Context, _ any) error {
×
2727
                        return nil
×
2728
                },
×
2729
        )
2730
}
2731

2732
// ChannelView returns the verifiable edge information for each active channel
2733
// within the known channel graph. The set of UTXOs (along with their scripts)
2734
// returned are the ones that need to be watched on chain to detect channel
2735
// closes on the resident blockchain.
2736
//
2737
// NOTE: part of the Store interface.
2738
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2739
        var (
×
2740
                ctx        = context.TODO()
×
2741
                edgePoints []EdgePoint
×
2742
        )
×
2743

×
2744
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2745
                handleChannel := func(_ context.Context,
×
2746
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2747

×
2748
                        pkScript, err := genMultiSigP2WSH(
×
2749
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2750
                        )
×
2751
                        if err != nil {
×
2752
                                return err
×
2753
                        }
×
2754

2755
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2756
                        if err != nil {
×
2757
                                return err
×
2758
                        }
×
2759

2760
                        edgePoints = append(edgePoints, EdgePoint{
×
2761
                                FundingPkScript: pkScript,
×
2762
                                OutPoint:        *op,
×
2763
                        })
×
2764

×
2765
                        return nil
×
2766
                }
2767

2768
                queryFunc := func(ctx context.Context, lastID int64,
×
2769
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2770

×
2771
                        return db.ListChannelsPaginated(
×
2772
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2773
                                        Version: int16(lnwire.GossipVersion1),
×
2774
                                        ID:      lastID,
×
2775
                                        Limit:   limit,
×
2776
                                },
×
2777
                        )
×
2778
                }
×
2779

2780
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2781
                        return row.ID
×
2782
                }
×
2783

2784
                return sqldb.ExecutePaginatedQuery(
×
2785
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2786
                        extractCursor, handleChannel,
×
2787
                )
×
2788
        }, func() {
×
2789
                edgePoints = nil
×
2790
        })
×
2791
        if err != nil {
×
2792
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2793
        }
×
2794

2795
        return edgePoints, nil
×
2796
}
2797

2798
// PruneTip returns the block height and hash of the latest block that has been
2799
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2800
// to tell if the graph is currently in sync with the current best known UTXO
2801
// state.
2802
//
2803
// NOTE: part of the Store interface.
2804
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2805
        var (
×
2806
                ctx       = context.TODO()
×
2807
                tipHash   chainhash.Hash
×
2808
                tipHeight uint32
×
2809
        )
×
2810
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2811
                pruneTip, err := db.GetPruneTip(ctx)
×
2812
                if errors.Is(err, sql.ErrNoRows) {
×
2813
                        return ErrGraphNeverPruned
×
2814
                } else if err != nil {
×
2815
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2816
                }
×
2817

2818
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2819
                tipHeight = uint32(pruneTip.BlockHeight)
×
2820

×
2821
                return nil
×
2822
        }, sqldb.NoOpReset)
2823
        if err != nil {
×
2824
                return nil, 0, err
×
2825
        }
×
2826

2827
        return &tipHash, tipHeight, nil
×
2828
}
2829

2830
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2831
//
2832
// NOTE: this prunes nodes across protocol versions. It will never prune the
2833
// source nodes.
2834
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2835
        db SQLQueries) ([]route.Vertex, error) {
×
2836

×
2837
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2838
        if err != nil {
×
2839
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2840
                        "nodes: %w", err)
×
2841
        }
×
2842

2843
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2844
        for i, nodeKey := range nodeKeys {
×
2845
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2846
                if err != nil {
×
2847
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2848
                                "from bytes: %w", err)
×
2849
                }
×
2850

2851
                prunedNodes[i] = pub
×
2852
        }
2853

2854
        return prunedNodes, nil
×
2855
}
2856

2857
// DisconnectBlockAtHeight is used to indicate that the block specified
2858
// by the passed height has been disconnected from the main chain. This
2859
// will "rewind" the graph back to the height below, deleting channels
2860
// that are no longer confirmed from the graph. The prune log will be
2861
// set to the last prune height valid for the remaining chain.
2862
// Channels that were removed from the graph resulting from the
2863
// disconnected block are returned.
2864
//
2865
// NOTE: part of the Store interface.
2866
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2867
        []*models.ChannelEdgeInfo, error) {
×
2868

×
2869
        ctx := context.TODO()
×
2870

×
2871
        var (
×
2872
                // Every channel having a ShortChannelID starting at 'height'
×
2873
                // will no longer be confirmed.
×
2874
                startShortChanID = lnwire.ShortChannelID{
×
2875
                        BlockHeight: height,
×
2876
                }
×
2877

×
2878
                // Delete everything after this height from the db up until the
×
2879
                // SCID alias range.
×
2880
                endShortChanID = aliasmgr.StartingAlias
×
2881

×
2882
                removedChans []*models.ChannelEdgeInfo
×
2883

×
2884
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2885
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2886
        )
×
2887

×
2888
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2889
                rows, err := db.GetChannelsBySCIDRange(
×
2890
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2891
                                StartScid: chanIDStart,
×
2892
                                EndScid:   chanIDEnd,
×
2893
                        },
×
2894
                )
×
2895
                if err != nil {
×
2896
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2897
                }
×
2898

2899
                if len(rows) == 0 {
×
2900
                        // No channels to disconnect, but still clean up prune
×
2901
                        // log.
×
2902
                        return db.DeletePruneLogEntriesInRange(
×
2903
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2904
                                        StartHeight: int64(height),
×
2905
                                        EndHeight: int64(
×
2906
                                                endShortChanID.BlockHeight,
×
2907
                                        ),
×
2908
                                },
×
2909
                        )
×
2910
                }
×
2911

2912
                // Batch build all channel edges for disconnection.
2913
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2914
                        ctx, s.cfg, db, rows,
×
2915
                )
×
2916
                if err != nil {
×
2917
                        return err
×
2918
                }
×
2919

2920
                removedChans = channelEdges
×
2921

×
2922
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2923
                if err != nil {
×
2924
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2925
                }
×
2926

2927
                return db.DeletePruneLogEntriesInRange(
×
2928
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2929
                                StartHeight: int64(height),
×
2930
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2931
                        },
×
2932
                )
×
2933
        }, func() {
×
2934
                removedChans = nil
×
2935
        })
×
2936
        if err != nil {
×
2937
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2938
                        "height: %w", err)
×
2939
        }
×
2940

2941
        for _, channel := range removedChans {
×
2942
                s.rejectCache.remove(channel.ChannelID)
×
2943
                s.chanCache.remove(channel.ChannelID)
×
2944
        }
×
2945

2946
        return removedChans, nil
×
2947
}
2948

2949
// AddEdgeProof sets the proof of an existing edge in the graph database.
2950
//
2951
// NOTE: part of the Store interface.
2952
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2953
        proof *models.ChannelAuthProof) error {
×
2954

×
2955
        var (
×
2956
                ctx       = context.TODO()
×
2957
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2958
        )
×
2959

×
2960
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2961
                res, err := db.AddV1ChannelProof(
×
2962
                        ctx, sqlc.AddV1ChannelProofParams{
×
2963
                                Scid:              scidBytes,
×
2964
                                Node1Signature:    proof.NodeSig1Bytes,
×
2965
                                Node2Signature:    proof.NodeSig2Bytes,
×
2966
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2967
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2968
                        },
×
2969
                )
×
2970
                if err != nil {
×
2971
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2972
                }
×
2973

2974
                n, err := res.RowsAffected()
×
2975
                if err != nil {
×
2976
                        return err
×
2977
                }
×
2978

2979
                if n == 0 {
×
2980
                        return fmt.Errorf("no rows affected when adding edge "+
×
2981
                                "proof for SCID %v", scid)
×
2982
                } else if n > 1 {
×
2983
                        return fmt.Errorf("multiple rows affected when adding "+
×
2984
                                "edge proof for SCID %v: %d rows affected",
×
2985
                                scid, n)
×
2986
                }
×
2987

2988
                return nil
×
2989
        }, sqldb.NoOpReset)
2990
        if err != nil {
×
2991
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2992
        }
×
2993

2994
        return nil
×
2995
}
2996

2997
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2998
// that we can ignore channel announcements that we know to be closed without
2999
// having to validate them and fetch a block.
3000
//
3001
// NOTE: part of the Store interface.
3002
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
3003
        var (
×
3004
                ctx     = context.TODO()
×
3005
                chanIDB = channelIDToBytes(scid.ToUint64())
×
3006
        )
×
3007

×
3008
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
3009
                return db.InsertClosedChannel(ctx, chanIDB)
×
3010
        }, sqldb.NoOpReset)
×
3011
}
3012

3013
// IsClosedScid checks whether a channel identified by the passed in scid is
3014
// closed. This helps avoid having to perform expensive validation checks.
3015
//
3016
// NOTE: part of the Store interface.
3017
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
3018
        var (
×
3019
                ctx      = context.TODO()
×
3020
                isClosed bool
×
3021
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
3022
        )
×
3023
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3024
                var err error
×
3025
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
3026
                if err != nil {
×
3027
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
3028
                                err)
×
3029
                }
×
3030

3031
                return nil
×
3032
        }, sqldb.NoOpReset)
3033
        if err != nil {
×
3034
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3035
                        err)
×
3036
        }
×
3037

3038
        return isClosed, nil
×
3039
}
3040

3041
// GraphSession will provide the call-back with access to a NodeTraverser
3042
// instance which can be used to perform queries against the channel graph.
3043
//
3044
// NOTE: part of the Store interface.
3045
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3046
        reset func()) error {
×
3047

×
3048
        var ctx = context.TODO()
×
3049

×
3050
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3051
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3052
        }, reset)
×
3053
}
3054

3055
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3056
// read only transaction for a consistent view of the graph.
3057
type sqlNodeTraverser struct {
3058
        db    SQLQueries
3059
        chain chainhash.Hash
3060
}
3061

3062
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3063
// NodeTraverser interface.
3064
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3065

3066
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3067
func newSQLNodeTraverser(db SQLQueries,
3068
        chain chainhash.Hash) *sqlNodeTraverser {
×
3069

×
3070
        return &sqlNodeTraverser{
×
3071
                db:    db,
×
3072
                chain: chain,
×
3073
        }
×
3074
}
×
3075

3076
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3077
// node.
3078
//
3079
// NOTE: Part of the NodeTraverser interface.
3080
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3081
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3082

×
3083
        ctx := context.TODO()
×
3084

×
3085
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3086
}
×
3087

3088
// FetchNodeFeatures returns the features of the given node. If the node is
3089
// unknown, assume no additional features are supported.
3090
//
3091
// NOTE: Part of the NodeTraverser interface.
3092
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3093
        *lnwire.FeatureVector, error) {
×
3094

×
3095
        ctx := context.TODO()
×
3096

×
NEW
3097
        return fetchNodeFeatures(ctx, s.db, lnwire.GossipVersion1, nodePub)
×
3098
}
×
3099

3100
// forEachNodeDirectedChannel iterates through all channels of a given
3101
// node, executing the passed callback on the directed edge representing the
3102
// channel and its incoming policy. If the node is not found, no error is
3103
// returned.
3104
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3105
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3106

×
3107
        toNodeCallback := func() route.Vertex {
×
3108
                return nodePub
×
3109
        }
×
3110

3111
        dbID, err := db.GetNodeIDByPubKey(
×
3112
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3113
                        Version: int16(lnwire.GossipVersion1),
×
3114
                        PubKey:  nodePub[:],
×
3115
                },
×
3116
        )
×
3117
        if errors.Is(err, sql.ErrNoRows) {
×
3118
                return nil
×
3119
        } else if err != nil {
×
3120
                return fmt.Errorf("unable to fetch node: %w", err)
×
3121
        }
×
3122

3123
        rows, err := db.ListChannelsByNodeID(
×
3124
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3125
                        Version: int16(lnwire.GossipVersion1),
×
3126
                        NodeID1: dbID,
×
3127
                },
×
3128
        )
×
3129
        if err != nil {
×
3130
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3131
        }
×
3132

3133
        // Exit early if there are no channels for this node so we don't
3134
        // do the unnecessary feature fetching.
3135
        if len(rows) == 0 {
×
3136
                return nil
×
3137
        }
×
3138

3139
        features, err := getNodeFeatures(ctx, db, dbID)
×
3140
        if err != nil {
×
3141
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3142
        }
×
3143

3144
        for _, row := range rows {
×
3145
                node1, node2, err := buildNodeVertices(
×
3146
                        row.Node1Pubkey, row.Node2Pubkey,
×
3147
                )
×
3148
                if err != nil {
×
3149
                        return fmt.Errorf("unable to build node vertices: %w",
×
3150
                                err)
×
3151
                }
×
3152

3153
                edge := buildCacheableChannelInfo(
×
3154
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3155
                        node1, node2,
×
3156
                )
×
3157

×
3158
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3159
                if err != nil {
×
3160
                        return err
×
3161
                }
×
3162

3163
                p1, p2, err := buildCachedChanPolicies(
×
3164
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3165
                )
×
3166
                if err != nil {
×
3167
                        return err
×
3168
                }
×
3169

3170
                // Determine the outgoing and incoming policy for this
3171
                // channel and node combo.
3172
                outPolicy, inPolicy := p1, p2
×
3173
                if p1 != nil && node2 == nodePub {
×
3174
                        outPolicy, inPolicy = p2, p1
×
3175
                } else if p2 != nil && node1 != nodePub {
×
3176
                        outPolicy, inPolicy = p2, p1
×
3177
                }
×
3178

3179
                var cachedInPolicy *models.CachedEdgePolicy
×
3180
                if inPolicy != nil {
×
3181
                        cachedInPolicy = inPolicy
×
3182
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3183
                        cachedInPolicy.ToNodeFeatures = features
×
3184
                }
×
3185

3186
                directedChannel := &DirectedChannel{
×
3187
                        ChannelID:    edge.ChannelID,
×
3188
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3189
                        OtherNode:    edge.NodeKey2Bytes,
×
3190
                        Capacity:     edge.Capacity,
×
3191
                        OutPolicySet: outPolicy != nil,
×
3192
                        InPolicy:     cachedInPolicy,
×
3193
                }
×
3194
                if outPolicy != nil {
×
3195
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3196
                                directedChannel.InboundFee = fee
×
3197
                        })
×
3198
                }
3199

3200
                if nodePub == edge.NodeKey2Bytes {
×
3201
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3202
                }
×
3203

3204
                if err := cb(directedChannel); err != nil {
×
3205
                        return err
×
3206
                }
×
3207
        }
3208

3209
        return nil
×
3210
}
3211

3212
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3213
// and executes the provided callback for each node. It does so via pagination
3214
// along with batch loading of the node feature bits.
3215
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3216
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3217
                features *lnwire.FeatureVector) error) error {
×
3218

×
3219
        handleNode := func(_ context.Context,
×
3220
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3221
                featureBits map[int64][]int) error {
×
3222

×
3223
                fv := lnwire.EmptyFeatureVector()
×
3224
                if features, exists := featureBits[dbNode.ID]; exists {
×
3225
                        for _, bit := range features {
×
3226
                                fv.Set(lnwire.FeatureBit(bit))
×
3227
                        }
×
3228
                }
3229

3230
                var pub route.Vertex
×
3231
                copy(pub[:], dbNode.PubKey)
×
3232

×
3233
                return processNode(dbNode.ID, pub, fv)
×
3234
        }
3235

3236
        queryFunc := func(ctx context.Context, lastID int64,
×
3237
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3238

×
3239
                return db.ListNodeIDsAndPubKeys(
×
3240
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3241
                                Version: int16(lnwire.GossipVersion1),
×
3242
                                ID:      lastID,
×
3243
                                Limit:   limit,
×
3244
                        },
×
3245
                )
×
3246
        }
×
3247

3248
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3249
                return row.ID
×
3250
        }
×
3251

3252
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3253
                return node.ID, nil
×
3254
        }
×
3255

3256
        batchQueryFunc := func(ctx context.Context,
×
3257
                nodeIDs []int64) (map[int64][]int, error) {
×
3258

×
3259
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3260
        }
×
3261

3262
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3263
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3264
                batchQueryFunc, handleNode,
×
3265
        )
×
3266
}
3267

3268
// forEachNodeChannel iterates through all channels of a node, executing
3269
// the passed callback on each. The call-back is provided with the channel's
3270
// edge information, the outgoing policy and the incoming policy for the
3271
// channel and node combo.
3272
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3273
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3274
                *models.ChannelEdgePolicy,
3275
                *models.ChannelEdgePolicy) error) error {
×
3276

×
3277
        // Get all the V1 channels for this node.
×
3278
        rows, err := db.ListChannelsByNodeID(
×
3279
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3280
                        Version: int16(lnwire.GossipVersion1),
×
3281
                        NodeID1: id,
×
3282
                },
×
3283
        )
×
3284
        if err != nil {
×
3285
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3286
        }
×
3287

3288
        // Collect all the channel and policy IDs.
3289
        var (
×
3290
                chanIDs   = make([]int64, 0, len(rows))
×
3291
                policyIDs = make([]int64, 0, 2*len(rows))
×
3292
        )
×
3293
        for _, row := range rows {
×
3294
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3295

×
3296
                if row.Policy1ID.Valid {
×
3297
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3298
                }
×
3299
                if row.Policy2ID.Valid {
×
3300
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3301
                }
×
3302
        }
3303

3304
        batchData, err := batchLoadChannelData(
×
3305
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3306
        )
×
3307
        if err != nil {
×
3308
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3309
        }
×
3310

3311
        // Call the call-back for each channel and its known policies.
3312
        for _, row := range rows {
×
3313
                node1, node2, err := buildNodeVertices(
×
3314
                        row.Node1Pubkey, row.Node2Pubkey,
×
3315
                )
×
3316
                if err != nil {
×
3317
                        return fmt.Errorf("unable to build node vertices: %w",
×
3318
                                err)
×
3319
                }
×
3320

3321
                edge, err := buildEdgeInfoWithBatchData(
×
3322
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3323
                        batchData,
×
3324
                )
×
3325
                if err != nil {
×
3326
                        return fmt.Errorf("unable to build channel info: %w",
×
3327
                                err)
×
3328
                }
×
3329

3330
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3331
                if err != nil {
×
3332
                        return fmt.Errorf("unable to extract channel "+
×
3333
                                "policies: %w", err)
×
3334
                }
×
3335

3336
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3337
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3338
                )
×
3339
                if err != nil {
×
3340
                        return fmt.Errorf("unable to build channel "+
×
3341
                                "policies: %w", err)
×
3342
                }
×
3343

3344
                // Determine the outgoing and incoming policy for this
3345
                // channel and node combo.
3346
                p1ToNode := row.GraphChannel.NodeID2
×
3347
                p2ToNode := row.GraphChannel.NodeID1
×
3348
                outPolicy, inPolicy := p1, p2
×
3349
                if (p1 != nil && p1ToNode == id) ||
×
3350
                        (p2 != nil && p2ToNode != id) {
×
3351

×
3352
                        outPolicy, inPolicy = p2, p1
×
3353
                }
×
3354

3355
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3356
                        return err
×
3357
                }
×
3358
        }
3359

3360
        return nil
×
3361
}
3362

3363
// updateChanEdgePolicy upserts the channel policy info we have stored for
3364
// a channel we already know of.
3365
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3366
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3367
        error) {
×
3368

×
3369
        var (
×
3370
                node1Pub, node2Pub route.Vertex
×
3371
                isNode1            bool
×
3372
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3373
        )
×
3374

×
3375
        // Check that this edge policy refers to a channel that we already
×
3376
        // know of. We do this explicitly so that we can return the appropriate
×
3377
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3378
        // abort the transaction which would abort the entire batch.
×
3379
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3380
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3381
                        Scid:    chanIDB,
×
3382
                        Version: int16(lnwire.GossipVersion1),
×
3383
                },
×
3384
        )
×
3385
        if errors.Is(err, sql.ErrNoRows) {
×
3386
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3387
        } else if err != nil {
×
3388
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3389
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3390
        }
×
3391

3392
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3393
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3394

×
3395
        // Figure out which node this edge is from.
×
3396
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3397
        nodeID := dbChan.NodeID1
×
3398
        if !isNode1 {
×
3399
                nodeID = dbChan.NodeID2
×
3400
        }
×
3401

3402
        var (
×
3403
                inboundBase sql.NullInt64
×
3404
                inboundRate sql.NullInt64
×
3405
        )
×
3406
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3407
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3408
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3409
        })
×
3410

3411
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3412
                Version:     int16(lnwire.GossipVersion1),
×
3413
                ChannelID:   dbChan.ID,
×
3414
                NodeID:      nodeID,
×
3415
                Timelock:    int32(edge.TimeLockDelta),
×
3416
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3417
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3418
                MinHtlcMsat: int64(edge.MinHTLC),
×
3419
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3420
                Disabled: sql.NullBool{
×
3421
                        Valid: true,
×
3422
                        Bool:  edge.IsDisabled(),
×
3423
                },
×
3424
                MaxHtlcMsat: sql.NullInt64{
×
3425
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3426
                        Int64: int64(edge.MaxHTLC),
×
3427
                },
×
3428
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3429
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3430
                InboundBaseFeeMsat:      inboundBase,
×
3431
                InboundFeeRateMilliMsat: inboundRate,
×
3432
                Signature:               edge.SigBytes,
×
3433
        })
×
3434
        if err != nil {
×
3435
                return node1Pub, node2Pub, isNode1,
×
3436
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3437
        }
×
3438

3439
        // Convert the flat extra opaque data into a map of TLV types to
3440
        // values.
3441
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3442
        if err != nil {
×
3443
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3444
                        "marshal extra opaque data: %w", err)
×
3445
        }
×
3446

3447
        // Update the channel policy's extra signed fields.
3448
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3449
        if err != nil {
×
3450
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3451
                        "policy extra TLVs: %w", err)
×
3452
        }
×
3453

3454
        return node1Pub, node2Pub, isNode1, nil
×
3455
}
3456

3457
// getNodeByPubKey attempts to look up a target node by its public key.
3458
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3459
        v lnwire.GossipVersion, pubKey route.Vertex) (int64, *models.Node,
NEW
3460
        error) {
×
3461

×
3462
        dbNode, err := db.GetNodeByPubKey(
×
3463
                ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
3464
                        Version: int16(v),
×
3465
                        PubKey:  pubKey[:],
×
3466
                },
×
3467
        )
×
3468
        if errors.Is(err, sql.ErrNoRows) {
×
3469
                return 0, nil, ErrGraphNodeNotFound
×
3470
        } else if err != nil {
×
3471
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3472
        }
×
3473

3474
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3475
        if err != nil {
×
3476
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3477
        }
×
3478

3479
        return dbNode.ID, node, nil
×
3480
}
3481

3482
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3483
// provided parameters.
3484
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3485
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3486

×
3487
        return &models.CachedEdgeInfo{
×
3488
                ChannelID:     byteOrder.Uint64(scid),
×
3489
                NodeKey1Bytes: node1Pub,
×
3490
                NodeKey2Bytes: node2Pub,
×
3491
                Capacity:      btcutil.Amount(capacity),
×
3492
        }
×
3493
}
×
3494

3495
// buildNode constructs a Node instance from the given database node
3496
// record. The node's features, addresses and extra signed fields are also
3497
// fetched from the database and set on the node.
3498
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3499
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3500

×
3501
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3502
        if err != nil {
×
3503
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3504
                        err)
×
3505
        }
×
3506

3507
        return buildNodeWithBatchData(dbNode, data)
×
3508
}
3509

3510
// isKnownGossipVersion checks whether the provided gossip version is known
3511
// and supported.
NEW
3512
func isKnownGossipVersion(v lnwire.GossipVersion) bool {
×
NEW
3513
        switch v {
×
NEW
3514
        case lnwire.GossipVersion1:
×
NEW
3515
                return true
×
NEW
3516
        case lnwire.GossipVersion2:
×
NEW
3517
                return true
×
NEW
3518
        default:
×
NEW
3519
                return false
×
3520
        }
3521
}
3522

3523
// buildNodeWithBatchData builds a models.Node instance
3524
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3525
// features/addresses/extra fields, then the corresponding fields are expected
3526
// to be present in the batchNodeData.
3527
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3528
        batchData *batchNodeData) (*models.Node, error) {
×
3529

×
NEW
3530
        v := lnwire.GossipVersion(dbNode.Version)
×
NEW
3531

×
NEW
3532
        if !isKnownGossipVersion(v) {
×
NEW
3533
                return nil, fmt.Errorf("unknown node version: %d", v)
×
UNCOV
3534
        }
×
3535

NEW
3536
        pub, err := route.NewVertexFromBytes(dbNode.PubKey)
×
NEW
3537
        if err != nil {
×
NEW
3538
                return nil, fmt.Errorf("unable to parse pubkey: %w", err)
×
NEW
3539
        }
×
3540

NEW
3541
        node := models.NewShellNode(v, pub)
×
3542

×
3543
        if len(dbNode.Signature) == 0 {
×
3544
                return node, nil
×
3545
        }
×
3546

3547
        node.AuthSigBytes = dbNode.Signature
×
3548

×
3549
        if dbNode.Alias.Valid {
×
3550
                node.Alias = fn.Some(dbNode.Alias.String)
×
3551
        }
×
3552
        if dbNode.LastUpdate.Valid {
×
3553
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3554
        }
×
NEW
3555
        if dbNode.BlockHeight.Valid {
×
NEW
3556
                node.LastBlockHeight = uint32(dbNode.BlockHeight.Int64)
×
NEW
3557
        }
×
3558

3559
        if dbNode.Color.Valid {
×
3560
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3561
                if err != nil {
×
3562
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3563
                                err)
×
3564
                }
×
3565

3566
                node.Color = fn.Some(nodeColor)
×
3567
        }
3568

3569
        // Use preloaded features.
3570
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3571
                fv := lnwire.EmptyFeatureVector()
×
3572
                for _, bit := range features {
×
3573
                        fv.Set(lnwire.FeatureBit(bit))
×
3574
                }
×
3575
                node.Features = fv
×
3576
        }
3577

3578
        // Use preloaded addresses.
3579
        addresses, exists := batchData.addresses[dbNode.ID]
×
3580
        if exists && len(addresses) > 0 {
×
3581
                node.Addresses, err = buildNodeAddresses(addresses)
×
3582
                if err != nil {
×
3583
                        return nil, fmt.Errorf("unable to build addresses "+
×
3584
                                "for node(%d): %w", dbNode.ID, err)
×
3585
                }
×
3586
        }
3587

3588
        // Use preloaded extra fields.
3589
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
NEW
3590
                if v == lnwire.GossipVersion1 {
×
NEW
3591
                        records := lnwire.CustomRecords(extraFields)
×
NEW
3592
                        recs, err := records.Serialize()
×
NEW
3593
                        if err != nil {
×
NEW
3594
                                return nil, fmt.Errorf("unable to serialize "+
×
NEW
3595
                                        "extra signed fields: %w", err)
×
NEW
3596
                        }
×
3597

NEW
3598
                        if len(recs) != 0 {
×
NEW
3599
                                node.ExtraOpaqueData = recs
×
NEW
3600
                        }
×
NEW
3601
                } else if len(extraFields) > 0 {
×
NEW
3602
                        node.ExtraSignedFields = extraFields
×
UNCOV
3603
                }
×
3604
        }
3605

3606
        return node, nil
×
3607
}
3608

3609
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3610
// with the preloaded data, and executes the provided callback for each node.
3611
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3612
        db SQLQueries, nodes []sqlc.GraphNode,
3613
        cb func(dbID int64, node *models.Node) error) error {
×
3614

×
3615
        // Extract node IDs for batch loading.
×
3616
        nodeIDs := make([]int64, len(nodes))
×
3617
        for i, node := range nodes {
×
3618
                nodeIDs[i] = node.ID
×
3619
        }
×
3620

3621
        // Batch load all related data for this page.
3622
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3623
        if err != nil {
×
3624
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3625
        }
×
3626

3627
        for _, dbNode := range nodes {
×
3628
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3629
                if err != nil {
×
3630
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3631
                                dbNode.ID, err)
×
3632
                }
×
3633

3634
                if err := cb(dbNode.ID, node); err != nil {
×
3635
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3636
                                dbNode.ID, err)
×
3637
                }
×
3638
        }
3639

3640
        return nil
×
3641
}
3642

3643
// getNodeFeatures fetches the feature bits and constructs the feature vector
3644
// for a node with the given DB ID.
3645
func getNodeFeatures(ctx context.Context, db SQLQueries,
3646
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3647

×
3648
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3649
        if err != nil {
×
3650
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3651
                        nodeID, err)
×
3652
        }
×
3653

3654
        features := lnwire.EmptyFeatureVector()
×
3655
        for _, feature := range rows {
×
3656
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3657
        }
×
3658

3659
        return features, nil
×
3660
}
3661

3662
// upsertNode upserts the node record into the database. If the node already
3663
// exists, then the node's information is updated. If the node doesn't exist,
3664
// then a new node is created. The node's features, addresses and extra TLV
3665
// types are also updated. The node's DB ID is returned.
3666
func upsertNode(ctx context.Context, db SQLQueries,
3667
        node *models.Node) (int64, error) {
×
3668

×
NEW
3669
        if !isKnownGossipVersion(node.Version) {
×
NEW
3670
                return 0, fmt.Errorf("unknown gossip version: %d", node.Version)
×
NEW
3671
        }
×
3672

3673
        params := sqlc.UpsertNodeParams{
×
NEW
3674
                Version: int16(node.Version),
×
3675
                PubKey:  node.PubKeyBytes[:],
×
3676
        }
×
3677

×
3678
        if node.HaveAnnouncement() {
×
3679
                switch node.Version {
×
3680
                case lnwire.GossipVersion1:
×
3681
                        params.LastUpdate = sqldb.SQLInt64(
×
3682
                                node.LastUpdate.Unix(),
×
3683
                        )
×
3684

3685
                case lnwire.GossipVersion2:
×
NEW
3686
                        params.BlockHeight = sqldb.SQLInt64(
×
NEW
3687
                                int32(node.LastBlockHeight),
×
NEW
3688
                        )
×
3689

3690
                default:
×
3691
                        return 0, fmt.Errorf("unknown gossip version: %d",
×
3692
                                node.Version)
×
3693
                }
3694

3695
                node.Color.WhenSome(func(rgba color.RGBA) {
×
3696
                        params.Color = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
3697
                })
×
3698
                node.Alias.WhenSome(func(s string) {
×
3699
                        params.Alias = sqldb.SQLStrValid(s)
×
3700
                })
×
3701

3702
                params.Signature = node.AuthSigBytes
×
3703
        }
3704

3705
        nodeID, err := db.UpsertNode(ctx, params)
×
3706
        if err != nil {
×
3707
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3708
                        err)
×
3709
        }
×
3710

3711
        // We can exit here if we don't have the announcement yet.
3712
        if !node.HaveAnnouncement() {
×
3713
                return nodeID, nil
×
3714
        }
×
3715

3716
        // Update the node's features.
3717
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3718
        if err != nil {
×
3719
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3720
        }
×
3721

3722
        // Update the node's addresses.
3723
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3724
        if err != nil {
×
3725
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3726
        }
×
3727

3728
        // Convert the flat extra opaque data into a map of TLV types to
3729
        // values.
NEW
3730
        extra := node.ExtraSignedFields
×
NEW
3731
        if node.Version == lnwire.GossipVersion1 {
×
NEW
3732
                extra, err = marshalExtraOpaqueData(node.ExtraOpaqueData)
×
NEW
3733
                if err != nil {
×
NEW
3734
                        return 0, fmt.Errorf("unable to marshal extra opaque "+
×
NEW
3735
                                "data: %w", err)
×
NEW
3736
                }
×
3737
        }
3738

3739
        // Update the node's extra signed fields.
3740
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3741
        if err != nil {
×
3742
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3743
        }
×
3744

3745
        return nodeID, nil
×
3746
}
3747

3748
// upsertNodeFeatures updates the node's features node_features table. This
3749
// includes deleting any feature bits no longer present and inserting any new
3750
// feature bits. If the feature bit does not yet exist in the features table,
3751
// then an entry is created in that table first.
3752
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3753
        features *lnwire.FeatureVector) error {
×
3754

×
3755
        // Get any existing features for the node.
×
3756
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3757
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3758
                return err
×
3759
        }
×
3760

3761
        // Copy the nodes latest set of feature bits.
3762
        newFeatures := make(map[int32]struct{})
×
3763
        if features != nil {
×
3764
                for feature := range features.Features() {
×
3765
                        newFeatures[int32(feature)] = struct{}{}
×
3766
                }
×
3767
        }
3768

3769
        // For any current feature that already exists in the DB, remove it from
3770
        // the in-memory map. For any existing feature that does not exist in
3771
        // the in-memory map, delete it from the database.
3772
        for _, feature := range existingFeatures {
×
3773
                // The feature is still present, so there are no updates to be
×
3774
                // made.
×
3775
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3776
                        delete(newFeatures, feature.FeatureBit)
×
3777
                        continue
×
3778
                }
3779

3780
                // The feature is no longer present, so we remove it from the
3781
                // database.
3782
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3783
                        NodeID:     nodeID,
×
3784
                        FeatureBit: feature.FeatureBit,
×
3785
                })
×
3786
                if err != nil {
×
3787
                        return fmt.Errorf("unable to delete node(%d) "+
×
3788
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3789
                                err)
×
3790
                }
×
3791
        }
3792

3793
        // Any remaining entries in newFeatures are new features that need to be
3794
        // added to the database for the first time.
3795
        for feature := range newFeatures {
×
3796
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3797
                        NodeID:     nodeID,
×
3798
                        FeatureBit: feature,
×
3799
                })
×
3800
                if err != nil {
×
3801
                        return fmt.Errorf("unable to insert node(%d) "+
×
3802
                                "feature(%v): %w", nodeID, feature, err)
×
3803
                }
×
3804
        }
3805

3806
        return nil
×
3807
}
3808

3809
// fetchNodeFeatures fetches the features for a node with the given public key.
3810
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3811
        v lnwire.GossipVersion, nodePub route.Vertex) (*lnwire.FeatureVector,
NEW
3812
        error) {
×
3813

×
3814
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3815
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3816
                        PubKey:  nodePub[:],
×
NEW
3817
                        Version: int16(v),
×
3818
                },
×
3819
        )
×
3820
        if err != nil {
×
3821
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3822
                        nodePub, err)
×
3823
        }
×
3824

3825
        features := lnwire.EmptyFeatureVector()
×
3826
        for _, bit := range rows {
×
3827
                features.Set(lnwire.FeatureBit(bit))
×
3828
        }
×
3829

3830
        return features, nil
×
3831
}
3832

3833
// dbAddressType enumerates the kinds of address stored in the node_addresses
// table. The type determines how an address is serialized when it is written
// and deserialized when it is read back. These values are persisted in the
// database, so existing ones must never be renumbered.
type dbAddressType uint8

const (
	// addressTypeIPv4 marks a TCP address whose IP is IPv4.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 marks a TCP address whose IP is IPv6.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 marks a Tor v2 onion service address.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 marks a Tor v3 onion service address.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS marks a DNS hostname address.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque marks an lnwire.OpaqueAddrs value.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3846

3847
// collectAddressRecords collects the addresses from the provided
3848
// net.Addr slice and returns a map of dbAddressType to a slice of address
3849
// strings.
3850
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3851
        error) {
×
3852

×
3853
        // Copy the nodes latest set of addresses.
×
3854
        newAddresses := map[dbAddressType][]string{
×
3855
                addressTypeIPv4:   {},
×
3856
                addressTypeIPv6:   {},
×
3857
                addressTypeTorV2:  {},
×
3858
                addressTypeTorV3:  {},
×
3859
                addressTypeDNS:    {},
×
3860
                addressTypeOpaque: {},
×
3861
        }
×
3862
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3863
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3864
        }
×
3865

3866
        for _, address := range addresses {
×
3867
                switch addr := address.(type) {
×
3868
                case *net.TCPAddr:
×
3869
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3870
                                addAddr(addressTypeIPv4, addr)
×
3871
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3872
                                addAddr(addressTypeIPv6, addr)
×
3873
                        } else {
×
3874
                                return nil, fmt.Errorf("unhandled IP "+
×
3875
                                        "address: %v", addr)
×
3876
                        }
×
3877

3878
                case *tor.OnionAddr:
×
3879
                        switch len(addr.OnionService) {
×
3880
                        case tor.V2Len:
×
3881
                                addAddr(addressTypeTorV2, addr)
×
3882
                        case tor.V3Len:
×
3883
                                addAddr(addressTypeTorV3, addr)
×
3884
                        default:
×
3885
                                return nil, fmt.Errorf("invalid length for " +
×
3886
                                        "a tor address")
×
3887
                        }
3888

3889
                case *lnwire.DNSAddress:
×
3890
                        addAddr(addressTypeDNS, addr)
×
3891

3892
                case *lnwire.OpaqueAddrs:
×
3893
                        addAddr(addressTypeOpaque, addr)
×
3894

3895
                default:
×
3896
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3897
                                addr)
×
3898
                }
3899
        }
3900

3901
        return newAddresses, nil
×
3902
}
3903

3904
// upsertNodeAddresses updates the node's addresses in the database. This
3905
// includes deleting any existing addresses and inserting the new set of
3906
// addresses. The deletion is necessary since the ordering of the addresses may
3907
// change, and we need to ensure that the database reflects the latest set of
3908
// addresses so that at the time of reconstructing the node announcement, the
3909
// order is preserved and the signature over the message remains valid.
3910
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3911
        addresses []net.Addr) error {
×
3912

×
3913
        // Delete any existing addresses for the node. This is required since
×
3914
        // even if the new set of addresses is the same, the ordering may have
×
3915
        // changed for a given address type.
×
3916
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3917
        if err != nil {
×
3918
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3919
                        nodeID, err)
×
3920
        }
×
3921

3922
        newAddresses, err := collectAddressRecords(addresses)
×
3923
        if err != nil {
×
3924
                return err
×
3925
        }
×
3926

3927
        // Any remaining entries in newAddresses are new addresses that need to
3928
        // be added to the database for the first time.
3929
        for addrType, addrList := range newAddresses {
×
3930
                for position, addr := range addrList {
×
3931
                        err := db.UpsertNodeAddress(
×
3932
                                ctx, sqlc.UpsertNodeAddressParams{
×
3933
                                        NodeID:   nodeID,
×
3934
                                        Type:     int16(addrType),
×
3935
                                        Address:  addr,
×
3936
                                        Position: int32(position),
×
3937
                                },
×
3938
                        )
×
3939
                        if err != nil {
×
3940
                                return fmt.Errorf("unable to insert "+
×
3941
                                        "node(%d) address(%v): %w", nodeID,
×
3942
                                        addr, err)
×
3943
                        }
×
3944
                }
3945
        }
3946

3947
        return nil
×
3948
}
3949

3950
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3951
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3952
        error) {
×
3953

×
3954
        // GetNodeAddresses ensures that the addresses for a given type are
×
3955
        // returned in the same order as they were inserted.
×
3956
        rows, err := db.GetNodeAddresses(ctx, id)
×
3957
        if err != nil {
×
3958
                return nil, err
×
3959
        }
×
3960

3961
        addresses := make([]net.Addr, 0, len(rows))
×
3962
        for _, row := range rows {
×
3963
                address := row.Address
×
3964

×
3965
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3966
                if err != nil {
×
3967
                        return nil, fmt.Errorf("unable to parse address "+
×
3968
                                "for node(%d): %v: %w", id, address, err)
×
3969
                }
×
3970

3971
                addresses = append(addresses, addr)
×
3972
        }
3973

3974
        // If we have no addresses, then we'll return nil instead of an
3975
        // empty slice.
3976
        if len(addresses) == 0 {
×
3977
                addresses = nil
×
3978
        }
×
3979

3980
        return addresses, nil
×
3981
}
3982

3983
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3984
// database. This includes updating any existing types, inserting any new types,
3985
// and deleting any types that are no longer present.
3986
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3987
        nodeID int64, extraFields map[uint64][]byte) error {
×
3988

×
3989
        // Get any existing extra signed fields for the node.
×
3990
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3991
        if err != nil {
×
3992
                return err
×
3993
        }
×
3994

3995
        // Make a lookup map of the existing field types so that we can use it
3996
        // to keep track of any fields we should delete.
3997
        m := make(map[uint64]bool)
×
3998
        for _, field := range existingFields {
×
3999
                m[uint64(field.Type)] = true
×
4000
        }
×
4001

4002
        // For all the new fields, we'll upsert them and remove them from the
4003
        // map of existing fields.
4004
        for tlvType, value := range extraFields {
×
4005
                err = db.UpsertNodeExtraType(
×
4006
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4007
                                NodeID: nodeID,
×
4008
                                Type:   int64(tlvType),
×
4009
                                Value:  value,
×
4010
                        },
×
4011
                )
×
4012
                if err != nil {
×
4013
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4014
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4015
                }
×
4016

4017
                // Remove the field from the map of existing fields if it was
4018
                // present.
4019
                delete(m, tlvType)
×
4020
        }
4021

4022
        // For all the fields that are left in the map of existing fields, we'll
4023
        // delete them as they are no longer present in the new set of fields.
4024
        for tlvType := range m {
×
4025
                err = db.DeleteExtraNodeType(
×
4026
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4027
                                NodeID: nodeID,
×
4028
                                Type:   int64(tlvType),
×
4029
                        },
×
4030
                )
×
4031
                if err != nil {
×
4032
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4033
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4034
                }
×
4035
        }
4036

4037
        return nil
×
4038
}
4039

4040
// srcNodeInfo holds the information about the source node of the graph.
4041
type srcNodeInfo struct {
4042
        // id is the DB level ID of the source node entry in the "nodes" table.
4043
        id int64
4044

4045
        // pub is the public key of the source node.
4046
        pub route.Vertex
4047
}
4048

4049
// sourceNode returns the DB node ID and pub key of the source node for the
4050
// specified protocol version.
4051
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4052
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4053

×
4054
        s.srcNodeMu.Lock()
×
4055
        defer s.srcNodeMu.Unlock()
×
4056

×
4057
        // If we already have the source node ID and pub key cached, then
×
4058
        // return them.
×
4059
        if info, ok := s.srcNodes[version]; ok {
×
4060
                return info.id, info.pub, nil
×
4061
        }
×
4062

4063
        var pubKey route.Vertex
×
4064

×
4065
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4066
        if err != nil {
×
4067
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4068
                        err)
×
4069
        }
×
4070

4071
        if len(nodes) == 0 {
×
4072
                return 0, pubKey, ErrSourceNodeNotSet
×
4073
        } else if len(nodes) > 1 {
×
4074
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4075
                        "protocol %s found", version)
×
4076
        }
×
4077

4078
        copy(pubKey[:], nodes[0].PubKey)
×
4079

×
4080
        s.srcNodes[version] = &srcNodeInfo{
×
4081
                id:  nodes[0].NodeID,
×
4082
                pub: pubKey,
×
4083
        }
×
4084

×
4085
        return nodes[0].NodeID, pubKey, nil
×
4086
}
4087

4088
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
4089
// This then produces a map from TLV type to value. If the input is not a
4090
// valid TLV stream, then an error is returned.
4091
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4092
        r := bytes.NewReader(data)
×
4093

×
4094
        tlvStream, err := tlv.NewStream()
×
4095
        if err != nil {
×
4096
                return nil, err
×
4097
        }
×
4098

4099
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4100
        // pass it into the P2P decoding variant.
4101
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4102
        if err != nil {
×
4103
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4104
        }
×
4105
        if len(parsedTypes) == 0 {
×
4106
                return nil, nil
×
4107
        }
×
4108

4109
        records := make(map[uint64][]byte)
×
4110
        for k, v := range parsedTypes {
×
4111
                records[uint64(k)] = v
×
4112
        }
×
4113

4114
        return records, nil
×
4115
}
4116

4117
// insertChannel inserts a new channel record into the database.
4118
func insertChannel(ctx context.Context, db SQLQueries,
4119
        edge *models.ChannelEdgeInfo) error {
×
4120

×
4121
        // Make sure that at least a "shell" entry for each node is present in
×
4122
        // the nodes table.
×
4123
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4124
        if err != nil {
×
4125
                return fmt.Errorf("unable to create shell node: %w", err)
×
4126
        }
×
4127

4128
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4129
        if err != nil {
×
4130
                return fmt.Errorf("unable to create shell node: %w", err)
×
4131
        }
×
4132

4133
        var capacity sql.NullInt64
×
4134
        if edge.Capacity != 0 {
×
4135
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4136
        }
×
4137

4138
        createParams := sqlc.CreateChannelParams{
×
4139
                Version:     int16(lnwire.GossipVersion1),
×
4140
                Scid:        channelIDToBytes(edge.ChannelID),
×
4141
                NodeID1:     node1DBID,
×
4142
                NodeID2:     node2DBID,
×
4143
                Outpoint:    edge.ChannelPoint.String(),
×
4144
                Capacity:    capacity,
×
4145
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4146
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4147
        }
×
4148

×
4149
        if edge.AuthProof != nil {
×
4150
                proof := edge.AuthProof
×
4151

×
4152
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4153
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4154
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4155
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4156
        }
×
4157

4158
        // Insert the new channel record.
4159
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4160
        if err != nil {
×
4161
                return err
×
4162
        }
×
4163

4164
        // Insert any channel features.
4165
        for feature := range edge.Features.Features() {
×
4166
                err = db.InsertChannelFeature(
×
4167
                        ctx, sqlc.InsertChannelFeatureParams{
×
4168
                                ChannelID:  dbChanID,
×
4169
                                FeatureBit: int32(feature),
×
4170
                        },
×
4171
                )
×
4172
                if err != nil {
×
4173
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4174
                                "feature(%v): %w", dbChanID, feature, err)
×
4175
                }
×
4176
        }
4177

4178
        // Finally, insert any extra TLV fields in the channel announcement.
4179
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4180
        if err != nil {
×
4181
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4182
                        err)
×
4183
        }
×
4184

4185
        for tlvType, value := range extra {
×
4186
                err := db.UpsertChannelExtraType(
×
4187
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4188
                                ChannelID: dbChanID,
×
4189
                                Type:      int64(tlvType),
×
4190
                                Value:     value,
×
4191
                        },
×
4192
                )
×
4193
                if err != nil {
×
4194
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4195
                                "extra signed field(%v): %w", edge.ChannelID,
×
4196
                                tlvType, err)
×
4197
                }
×
4198
        }
4199

4200
        return nil
×
4201
}
4202

4203
// maybeCreateShellNode checks if a shell node entry exists for the
4204
// given public key. If it does not exist, then a new shell node entry is
4205
// created. The ID of the node is returned. A shell node only has a protocol
4206
// version and public key persisted.
4207
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4208
        pubKey route.Vertex) (int64, error) {
×
4209

×
4210
        dbNode, err := db.GetNodeByPubKey(
×
4211
                ctx, sqlc.GetNodeByPubKeyParams{
×
4212
                        PubKey:  pubKey[:],
×
4213
                        Version: int16(lnwire.GossipVersion1),
×
4214
                },
×
4215
        )
×
4216
        // The node exists. Return the ID.
×
4217
        if err == nil {
×
4218
                return dbNode.ID, nil
×
4219
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4220
                return 0, err
×
4221
        }
×
4222

4223
        // Otherwise, the node does not exist, so we create a shell entry for
4224
        // it.
4225
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4226
                Version: int16(lnwire.GossipVersion1),
×
4227
                PubKey:  pubKey[:],
×
4228
        })
×
4229
        if err != nil {
×
4230
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4231
        }
×
4232

4233
        return id, nil
×
4234
}
4235

4236
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4237
// the database. This includes deleting any existing types and then inserting
4238
// the new types.
4239
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4240
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4241

×
4242
        // Delete all existing extra signed fields for the channel policy.
×
4243
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4244
        if err != nil {
×
4245
                return fmt.Errorf("unable to delete "+
×
4246
                        "existing policy extra signed fields for policy %d: %w",
×
4247
                        chanPolicyID, err)
×
4248
        }
×
4249

4250
        // Insert all new extra signed fields for the channel policy.
4251
        for tlvType, value := range extraFields {
×
4252
                err = db.UpsertChanPolicyExtraType(
×
4253
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4254
                                ChannelPolicyID: chanPolicyID,
×
4255
                                Type:            int64(tlvType),
×
4256
                                Value:           value,
×
4257
                        },
×
4258
                )
×
4259
                if err != nil {
×
4260
                        return fmt.Errorf("unable to insert "+
×
4261
                                "channel_policy(%d) extra signed field(%v): %w",
×
4262
                                chanPolicyID, tlvType, err)
×
4263
                }
×
4264
        }
4265

4266
        return nil
×
4267
}
4268

4269
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4270
// provided dbChanRow and also fetches any other required information
4271
// to construct the edge info.
4272
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4273
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4274
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4275

×
4276
        data, err := batchLoadChannelData(
×
4277
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4278
        )
×
4279
        if err != nil {
×
4280
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4281
                        err)
×
4282
        }
×
4283

4284
        return buildEdgeInfoWithBatchData(
×
4285
                cfg.ChainHash, dbChan, node1, node2, data,
×
4286
        )
×
4287
}
4288

4289
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4290
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4291
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4292
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4293

×
4294
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4295
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4296
                        dbChan.Version)
×
4297
        }
×
4298

4299
        // Use pre-loaded features and extras types.
4300
        fv := lnwire.EmptyFeatureVector()
×
4301
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4302
                for _, bit := range features {
×
4303
                        fv.Set(lnwire.FeatureBit(bit))
×
4304
                }
×
4305
        }
4306

4307
        var extras map[uint64][]byte
×
4308
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4309
        if exists {
×
4310
                extras = channelExtras
×
4311
        } else {
×
4312
                extras = make(map[uint64][]byte)
×
4313
        }
×
4314

4315
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4316
        if err != nil {
×
4317
                return nil, err
×
4318
        }
×
4319

4320
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4321
        if err != nil {
×
4322
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4323
                        "fields: %w", err)
×
4324
        }
×
4325
        if recs == nil {
×
4326
                recs = make([]byte, 0)
×
4327
        }
×
4328

4329
        var btcKey1, btcKey2 route.Vertex
×
4330
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4331
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4332

×
4333
        channel := &models.ChannelEdgeInfo{
×
4334
                ChainHash:        chain,
×
4335
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4336
                NodeKey1Bytes:    node1,
×
4337
                NodeKey2Bytes:    node2,
×
4338
                BitcoinKey1Bytes: btcKey1,
×
4339
                BitcoinKey2Bytes: btcKey2,
×
4340
                ChannelPoint:     *op,
×
4341
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4342
                Features:         fv,
×
4343
                ExtraOpaqueData:  recs,
×
4344
        }
×
4345

×
4346
        // We always set all the signatures at the same time, so we can
×
4347
        // safely check if one signature is present to determine if we have the
×
4348
        // rest of the signatures for the auth proof.
×
4349
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4350
                channel.AuthProof = &models.ChannelAuthProof{
×
4351
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4352
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4353
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4354
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4355
                }
×
4356
        }
×
4357

4358
        return channel, nil
×
4359
}
4360

4361
// buildNodeVertices is a helper that converts raw node public keys
4362
// into route.Vertex instances.
4363
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4364
        route.Vertex, error) {
×
4365

×
4366
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4367
        if err != nil {
×
4368
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4369
                        "create vertex from node1 pubkey: %w", err)
×
4370
        }
×
4371

4372
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4373
        if err != nil {
×
4374
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4375
                        "create vertex from node2 pubkey: %w", err)
×
4376
        }
×
4377

4378
        return node1Vertex, node2Vertex, nil
×
4379
}
4380

4381
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4382
// retrieves all the extra info required to build the complete
4383
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4384
// the provided sqlc.GraphChannelPolicy records are nil.
4385
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4386
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4387
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4388
        *models.ChannelEdgePolicy, error) {
×
4389

×
4390
        if dbPol1 == nil && dbPol2 == nil {
×
4391
                return nil, nil, nil
×
4392
        }
×
4393

4394
        var policyIDs = make([]int64, 0, 2)
×
4395
        if dbPol1 != nil {
×
4396
                policyIDs = append(policyIDs, dbPol1.ID)
×
4397
        }
×
4398
        if dbPol2 != nil {
×
4399
                policyIDs = append(policyIDs, dbPol2.ID)
×
4400
        }
×
4401

4402
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4403
        if err != nil {
×
4404
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4405
                        "data: %w", err)
×
4406
        }
×
4407

4408
        pol1, err := buildChanPolicyWithBatchData(
×
4409
                dbPol1, channelID, node2, batchData,
×
4410
        )
×
4411
        if err != nil {
×
4412
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4413
        }
×
4414

4415
        pol2, err := buildChanPolicyWithBatchData(
×
4416
                dbPol2, channelID, node1, batchData,
×
4417
        )
×
4418
        if err != nil {
×
4419
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4420
        }
×
4421

4422
        return pol1, pol2, nil
×
4423
}
4424

4425
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4426
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4427
// then nil is returned for it.
4428
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4429
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4430
        *models.CachedEdgePolicy, error) {
×
4431

×
4432
        var p1, p2 *models.CachedEdgePolicy
×
4433
        if dbPol1 != nil {
×
4434
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4435
                if err != nil {
×
4436
                        return nil, nil, err
×
4437
                }
×
4438

4439
                p1 = models.NewCachedPolicy(policy1)
×
4440
        }
4441
        if dbPol2 != nil {
×
4442
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4443
                if err != nil {
×
4444
                        return nil, nil, err
×
4445
                }
×
4446

4447
                p2 = models.NewCachedPolicy(policy2)
×
4448
        }
4449

4450
        return p1, p2, nil
×
4451
}
4452

4453
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4454
// provided sqlc.GraphChannelPolicy and other required information.
4455
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4456
        extras map[uint64][]byte,
4457
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4458

×
4459
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4460
        if err != nil {
×
4461
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4462
                        "fields: %w", err)
×
4463
        }
×
4464

4465
        var inboundFee fn.Option[lnwire.Fee]
×
4466
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4467
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4468

×
4469
                inboundFee = fn.Some(lnwire.Fee{
×
4470
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4471
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4472
                })
×
4473
        }
×
4474

4475
        return &models.ChannelEdgePolicy{
×
4476
                SigBytes:  dbPolicy.Signature,
×
4477
                ChannelID: channelID,
×
4478
                LastUpdate: time.Unix(
×
4479
                        dbPolicy.LastUpdate.Int64, 0,
×
4480
                ),
×
4481
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4482
                        dbPolicy.MessageFlags,
×
4483
                ),
×
4484
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4485
                        dbPolicy.ChannelFlags,
×
4486
                ),
×
4487
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4488
                MinHTLC: lnwire.MilliSatoshi(
×
4489
                        dbPolicy.MinHtlcMsat,
×
4490
                ),
×
4491
                MaxHTLC: lnwire.MilliSatoshi(
×
4492
                        dbPolicy.MaxHtlcMsat.Int64,
×
4493
                ),
×
4494
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4495
                        dbPolicy.BaseFeeMsat,
×
4496
                ),
×
4497
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4498
                ToNode:                    toNode,
×
4499
                InboundFee:                inboundFee,
×
4500
                ExtraOpaqueData:           recs,
×
4501
        }, nil
×
4502
}
4503

4504
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4505
// row which is expected to be a sqlc type that contains channel policy
4506
// information. It returns two policies, which may be nil if the policy
4507
// information is not present in the row.
4508
//
4509
//nolint:ll,dupl,funlen
4510
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4511
        *sqlc.GraphChannelPolicy, error) {
×
4512

×
4513
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4514
        switch r := row.(type) {
×
4515
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4516
                if r.Policy1Timelock.Valid {
×
4517
                        policy1 = &sqlc.GraphChannelPolicy{
×
4518
                                Timelock:                r.Policy1Timelock.Int32,
×
4519
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4520
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4521
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4522
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4523
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4524
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4525
                                Disabled:                r.Policy1Disabled,
×
4526
                                MessageFlags:            r.Policy1MessageFlags,
×
4527
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4528
                        }
×
4529
                }
×
4530
                if r.Policy2Timelock.Valid {
×
4531
                        policy2 = &sqlc.GraphChannelPolicy{
×
4532
                                Timelock:                r.Policy2Timelock.Int32,
×
4533
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4534
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4535
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4536
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4537
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4538
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4539
                                Disabled:                r.Policy2Disabled,
×
4540
                                MessageFlags:            r.Policy2MessageFlags,
×
4541
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4542
                        }
×
4543
                }
×
4544

4545
                return policy1, policy2, nil
×
4546

4547
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4548
                if r.Policy1ID.Valid {
×
4549
                        policy1 = &sqlc.GraphChannelPolicy{
×
4550
                                ID:                      r.Policy1ID.Int64,
×
4551
                                Version:                 r.Policy1Version.Int16,
×
4552
                                ChannelID:               r.GraphChannel.ID,
×
4553
                                NodeID:                  r.Policy1NodeID.Int64,
×
4554
                                Timelock:                r.Policy1Timelock.Int32,
×
4555
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4556
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4557
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4558
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4559
                                LastUpdate:              r.Policy1LastUpdate,
×
4560
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4561
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4562
                                Disabled:                r.Policy1Disabled,
×
4563
                                MessageFlags:            r.Policy1MessageFlags,
×
4564
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4565
                                Signature:               r.Policy1Signature,
×
4566
                        }
×
4567
                }
×
4568
                if r.Policy2ID.Valid {
×
4569
                        policy2 = &sqlc.GraphChannelPolicy{
×
4570
                                ID:                      r.Policy2ID.Int64,
×
4571
                                Version:                 r.Policy2Version.Int16,
×
4572
                                ChannelID:               r.GraphChannel.ID,
×
4573
                                NodeID:                  r.Policy2NodeID.Int64,
×
4574
                                Timelock:                r.Policy2Timelock.Int32,
×
4575
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4576
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4577
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4578
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4579
                                LastUpdate:              r.Policy2LastUpdate,
×
4580
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4581
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4582
                                Disabled:                r.Policy2Disabled,
×
4583
                                MessageFlags:            r.Policy2MessageFlags,
×
4584
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4585
                                Signature:               r.Policy2Signature,
×
4586
                        }
×
4587
                }
×
4588

4589
                return policy1, policy2, nil
×
4590

4591
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4592
                if r.Policy1ID.Valid {
×
4593
                        policy1 = &sqlc.GraphChannelPolicy{
×
4594
                                ID:                      r.Policy1ID.Int64,
×
4595
                                Version:                 r.Policy1Version.Int16,
×
4596
                                ChannelID:               r.GraphChannel.ID,
×
4597
                                NodeID:                  r.Policy1NodeID.Int64,
×
4598
                                Timelock:                r.Policy1Timelock.Int32,
×
4599
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4600
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4601
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4602
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4603
                                LastUpdate:              r.Policy1LastUpdate,
×
4604
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4605
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4606
                                Disabled:                r.Policy1Disabled,
×
4607
                                MessageFlags:            r.Policy1MessageFlags,
×
4608
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4609
                                Signature:               r.Policy1Signature,
×
4610
                        }
×
4611
                }
×
4612
                if r.Policy2ID.Valid {
×
4613
                        policy2 = &sqlc.GraphChannelPolicy{
×
4614
                                ID:                      r.Policy2ID.Int64,
×
4615
                                Version:                 r.Policy2Version.Int16,
×
4616
                                ChannelID:               r.GraphChannel.ID,
×
4617
                                NodeID:                  r.Policy2NodeID.Int64,
×
4618
                                Timelock:                r.Policy2Timelock.Int32,
×
4619
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4620
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4621
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4622
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4623
                                LastUpdate:              r.Policy2LastUpdate,
×
4624
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4625
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4626
                                Disabled:                r.Policy2Disabled,
×
4627
                                MessageFlags:            r.Policy2MessageFlags,
×
4628
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4629
                                Signature:               r.Policy2Signature,
×
4630
                        }
×
4631
                }
×
4632

4633
                return policy1, policy2, nil
×
4634

4635
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4636
                if r.Policy1ID.Valid {
×
4637
                        policy1 = &sqlc.GraphChannelPolicy{
×
4638
                                ID:                      r.Policy1ID.Int64,
×
4639
                                Version:                 r.Policy1Version.Int16,
×
4640
                                ChannelID:               r.GraphChannel.ID,
×
4641
                                NodeID:                  r.Policy1NodeID.Int64,
×
4642
                                Timelock:                r.Policy1Timelock.Int32,
×
4643
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4644
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4645
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4646
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4647
                                LastUpdate:              r.Policy1LastUpdate,
×
4648
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4649
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4650
                                Disabled:                r.Policy1Disabled,
×
4651
                                MessageFlags:            r.Policy1MessageFlags,
×
4652
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4653
                                Signature:               r.Policy1Signature,
×
4654
                        }
×
4655
                }
×
4656
                if r.Policy2ID.Valid {
×
4657
                        policy2 = &sqlc.GraphChannelPolicy{
×
4658
                                ID:                      r.Policy2ID.Int64,
×
4659
                                Version:                 r.Policy2Version.Int16,
×
4660
                                ChannelID:               r.GraphChannel.ID,
×
4661
                                NodeID:                  r.Policy2NodeID.Int64,
×
4662
                                Timelock:                r.Policy2Timelock.Int32,
×
4663
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4664
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4665
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4666
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4667
                                LastUpdate:              r.Policy2LastUpdate,
×
4668
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4669
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4670
                                Disabled:                r.Policy2Disabled,
×
4671
                                MessageFlags:            r.Policy2MessageFlags,
×
4672
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4673
                                Signature:               r.Policy2Signature,
×
4674
                        }
×
4675
                }
×
4676

4677
                return policy1, policy2, nil
×
4678

4679
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4680
                if r.Policy1ID.Valid {
×
4681
                        policy1 = &sqlc.GraphChannelPolicy{
×
4682
                                ID:                      r.Policy1ID.Int64,
×
4683
                                Version:                 r.Policy1Version.Int16,
×
4684
                                ChannelID:               r.GraphChannel.ID,
×
4685
                                NodeID:                  r.Policy1NodeID.Int64,
×
4686
                                Timelock:                r.Policy1Timelock.Int32,
×
4687
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4688
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4689
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4690
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4691
                                LastUpdate:              r.Policy1LastUpdate,
×
4692
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4693
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4694
                                Disabled:                r.Policy1Disabled,
×
4695
                                MessageFlags:            r.Policy1MessageFlags,
×
4696
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4697
                                Signature:               r.Policy1Signature,
×
4698
                        }
×
4699
                }
×
4700
                if r.Policy2ID.Valid {
×
4701
                        policy2 = &sqlc.GraphChannelPolicy{
×
4702
                                ID:                      r.Policy2ID.Int64,
×
4703
                                Version:                 r.Policy2Version.Int16,
×
4704
                                ChannelID:               r.GraphChannel.ID,
×
4705
                                NodeID:                  r.Policy2NodeID.Int64,
×
4706
                                Timelock:                r.Policy2Timelock.Int32,
×
4707
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4708
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4709
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4710
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4711
                                LastUpdate:              r.Policy2LastUpdate,
×
4712
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4713
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4714
                                Disabled:                r.Policy2Disabled,
×
4715
                                MessageFlags:            r.Policy2MessageFlags,
×
4716
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4717
                                Signature:               r.Policy2Signature,
×
4718
                        }
×
4719
                }
×
4720

4721
                return policy1, policy2, nil
×
4722

4723
        case sqlc.ListChannelsForNodeIDsRow:
×
4724
                if r.Policy1ID.Valid {
×
4725
                        policy1 = &sqlc.GraphChannelPolicy{
×
4726
                                ID:                      r.Policy1ID.Int64,
×
4727
                                Version:                 r.Policy1Version.Int16,
×
4728
                                ChannelID:               r.GraphChannel.ID,
×
4729
                                NodeID:                  r.Policy1NodeID.Int64,
×
4730
                                Timelock:                r.Policy1Timelock.Int32,
×
4731
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4732
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4733
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4734
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4735
                                LastUpdate:              r.Policy1LastUpdate,
×
4736
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4737
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4738
                                Disabled:                r.Policy1Disabled,
×
4739
                                MessageFlags:            r.Policy1MessageFlags,
×
4740
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4741
                                Signature:               r.Policy1Signature,
×
4742
                        }
×
4743
                }
×
4744
                if r.Policy2ID.Valid {
×
4745
                        policy2 = &sqlc.GraphChannelPolicy{
×
4746
                                ID:                      r.Policy2ID.Int64,
×
4747
                                Version:                 r.Policy2Version.Int16,
×
4748
                                ChannelID:               r.GraphChannel.ID,
×
4749
                                NodeID:                  r.Policy2NodeID.Int64,
×
4750
                                Timelock:                r.Policy2Timelock.Int32,
×
4751
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4752
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4753
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4754
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4755
                                LastUpdate:              r.Policy2LastUpdate,
×
4756
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4757
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4758
                                Disabled:                r.Policy2Disabled,
×
4759
                                MessageFlags:            r.Policy2MessageFlags,
×
4760
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4761
                                Signature:               r.Policy2Signature,
×
4762
                        }
×
4763
                }
×
4764

4765
                return policy1, policy2, nil
×
4766

4767
        case sqlc.ListChannelsByNodeIDRow:
×
4768
                if r.Policy1ID.Valid {
×
4769
                        policy1 = &sqlc.GraphChannelPolicy{
×
4770
                                ID:                      r.Policy1ID.Int64,
×
4771
                                Version:                 r.Policy1Version.Int16,
×
4772
                                ChannelID:               r.GraphChannel.ID,
×
4773
                                NodeID:                  r.Policy1NodeID.Int64,
×
4774
                                Timelock:                r.Policy1Timelock.Int32,
×
4775
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4776
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4777
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4778
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4779
                                LastUpdate:              r.Policy1LastUpdate,
×
4780
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4781
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4782
                                Disabled:                r.Policy1Disabled,
×
4783
                                MessageFlags:            r.Policy1MessageFlags,
×
4784
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4785
                                Signature:               r.Policy1Signature,
×
4786
                        }
×
4787
                }
×
4788
                if r.Policy2ID.Valid {
×
4789
                        policy2 = &sqlc.GraphChannelPolicy{
×
4790
                                ID:                      r.Policy2ID.Int64,
×
4791
                                Version:                 r.Policy2Version.Int16,
×
4792
                                ChannelID:               r.GraphChannel.ID,
×
4793
                                NodeID:                  r.Policy2NodeID.Int64,
×
4794
                                Timelock:                r.Policy2Timelock.Int32,
×
4795
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4796
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4797
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4798
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4799
                                LastUpdate:              r.Policy2LastUpdate,
×
4800
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4801
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4802
                                Disabled:                r.Policy2Disabled,
×
4803
                                MessageFlags:            r.Policy2MessageFlags,
×
4804
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4805
                                Signature:               r.Policy2Signature,
×
4806
                        }
×
4807
                }
×
4808

4809
                return policy1, policy2, nil
×
4810

4811
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4812
                if r.Policy1ID.Valid {
×
4813
                        policy1 = &sqlc.GraphChannelPolicy{
×
4814
                                ID:                      r.Policy1ID.Int64,
×
4815
                                Version:                 r.Policy1Version.Int16,
×
4816
                                ChannelID:               r.GraphChannel.ID,
×
4817
                                NodeID:                  r.Policy1NodeID.Int64,
×
4818
                                Timelock:                r.Policy1Timelock.Int32,
×
4819
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4820
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4821
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4822
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4823
                                LastUpdate:              r.Policy1LastUpdate,
×
4824
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4825
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4826
                                Disabled:                r.Policy1Disabled,
×
4827
                                MessageFlags:            r.Policy1MessageFlags,
×
4828
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4829
                                Signature:               r.Policy1Signature,
×
4830
                        }
×
4831
                }
×
4832
                if r.Policy2ID.Valid {
×
4833
                        policy2 = &sqlc.GraphChannelPolicy{
×
4834
                                ID:                      r.Policy2ID.Int64,
×
4835
                                Version:                 r.Policy2Version.Int16,
×
4836
                                ChannelID:               r.GraphChannel.ID,
×
4837
                                NodeID:                  r.Policy2NodeID.Int64,
×
4838
                                Timelock:                r.Policy2Timelock.Int32,
×
4839
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4840
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4841
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4842
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4843
                                LastUpdate:              r.Policy2LastUpdate,
×
4844
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4845
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4846
                                Disabled:                r.Policy2Disabled,
×
4847
                                MessageFlags:            r.Policy2MessageFlags,
×
4848
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4849
                                Signature:               r.Policy2Signature,
×
4850
                        }
×
4851
                }
×
4852

4853
                return policy1, policy2, nil
×
4854

4855
        case sqlc.GetChannelsByIDsRow:
×
4856
                if r.Policy1ID.Valid {
×
4857
                        policy1 = &sqlc.GraphChannelPolicy{
×
4858
                                ID:                      r.Policy1ID.Int64,
×
4859
                                Version:                 r.Policy1Version.Int16,
×
4860
                                ChannelID:               r.GraphChannel.ID,
×
4861
                                NodeID:                  r.Policy1NodeID.Int64,
×
4862
                                Timelock:                r.Policy1Timelock.Int32,
×
4863
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4864
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4865
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4866
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4867
                                LastUpdate:              r.Policy1LastUpdate,
×
4868
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4869
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4870
                                Disabled:                r.Policy1Disabled,
×
4871
                                MessageFlags:            r.Policy1MessageFlags,
×
4872
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4873
                                Signature:               r.Policy1Signature,
×
4874
                        }
×
4875
                }
×
4876
                if r.Policy2ID.Valid {
×
4877
                        policy2 = &sqlc.GraphChannelPolicy{
×
4878
                                ID:                      r.Policy2ID.Int64,
×
4879
                                Version:                 r.Policy2Version.Int16,
×
4880
                                ChannelID:               r.GraphChannel.ID,
×
4881
                                NodeID:                  r.Policy2NodeID.Int64,
×
4882
                                Timelock:                r.Policy2Timelock.Int32,
×
4883
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4884
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4885
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4886
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4887
                                LastUpdate:              r.Policy2LastUpdate,
×
4888
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4889
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4890
                                Disabled:                r.Policy2Disabled,
×
4891
                                MessageFlags:            r.Policy2MessageFlags,
×
4892
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4893
                                Signature:               r.Policy2Signature,
×
4894
                        }
×
4895
                }
×
4896

4897
                return policy1, policy2, nil
×
4898

4899
        default:
×
4900
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4901
                        "extractChannelPolicies: %T", r)
×
4902
        }
4903
}
4904

4905
// channelIDToBytes converts a channel ID (SCID) to a byte array
4906
// representation.
4907
func channelIDToBytes(channelID uint64) []byte {
×
4908
        var chanIDB [8]byte
×
4909
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4910

×
4911
        return chanIDB[:]
×
4912
}
×
4913

4914
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4915
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4916
        if len(addresses) == 0 {
×
4917
                return nil, nil
×
4918
        }
×
4919

4920
        result := make([]net.Addr, 0, len(addresses))
×
4921
        for _, addr := range addresses {
×
4922
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4923
                if err != nil {
×
4924
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4925
                                "of type %d: %w", addr.address, addr.addrType,
×
4926
                                err)
×
4927
                }
×
4928
                if netAddr != nil {
×
4929
                        result = append(result, netAddr)
×
4930
                }
×
4931
        }
4932

4933
        // If we have no valid addresses, return nil instead of empty slice.
4934
        if len(result) == 0 {
×
4935
                return nil, nil
×
4936
        }
×
4937

4938
        return result, nil
×
4939
}
4940

4941
// parseAddress parses the given address string based on the address type
4942
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4943
// and opaque addresses.
4944
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4945
        switch addrType {
×
4946
        case addressTypeIPv4:
×
4947
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4948
                if err != nil {
×
4949
                        return nil, err
×
4950
                }
×
4951

4952
                tcp.IP = tcp.IP.To4()
×
4953

×
4954
                return tcp, nil
×
4955

4956
        case addressTypeIPv6:
×
4957
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4958
                if err != nil {
×
4959
                        return nil, err
×
4960
                }
×
4961

4962
                return tcp, nil
×
4963

4964
        case addressTypeTorV3, addressTypeTorV2:
×
4965
                service, portStr, err := net.SplitHostPort(address)
×
4966
                if err != nil {
×
4967
                        return nil, fmt.Errorf("unable to split tor "+
×
4968
                                "address: %v", address)
×
4969
                }
×
4970

4971
                port, err := strconv.Atoi(portStr)
×
4972
                if err != nil {
×
4973
                        return nil, err
×
4974
                }
×
4975

4976
                return &tor.OnionAddr{
×
4977
                        OnionService: service,
×
4978
                        Port:         port,
×
4979
                }, nil
×
4980

4981
        case addressTypeDNS:
×
4982
                hostname, portStr, err := net.SplitHostPort(address)
×
4983
                if err != nil {
×
4984
                        return nil, fmt.Errorf("unable to split DNS "+
×
4985
                                "address: %v", address)
×
4986
                }
×
4987

4988
                port, err := strconv.Atoi(portStr)
×
4989
                if err != nil {
×
4990
                        return nil, err
×
4991
                }
×
4992

4993
                return &lnwire.DNSAddress{
×
4994
                        Hostname: hostname,
×
4995
                        Port:     uint16(port),
×
4996
                }, nil
×
4997

4998
        case addressTypeOpaque:
×
4999
                opaque, err := hex.DecodeString(address)
×
5000
                if err != nil {
×
5001
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5002
                                "address: %v", address)
×
5003
                }
×
5004

5005
                return &lnwire.OpaqueAddrs{
×
5006
                        Payload: opaque,
×
5007
                }, nil
×
5008

5009
        default:
×
5010
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5011
        }
5012
}
5013

5014
// batchNodeData holds all the related data for a batch of nodes.
5015
type batchNodeData struct {
5016
        // features is a map from a DB node ID to the feature bits for that
5017
        // node.
5018
        features map[int64][]int
5019

5020
        // addresses is a map from a DB node ID to the node's addresses.
5021
        addresses map[int64][]nodeAddress
5022

5023
        // extraFields is a map from a DB node ID to the extra signed fields
5024
        // for that node.
5025
        extraFields map[int64]map[uint64][]byte
5026
}
5027

5028
// nodeAddress holds the address type, position and address string for a
5029
// node. This is used to batch the fetching of node addresses.
5030
type nodeAddress struct {
5031
        addrType dbAddressType
5032
        position int32
5033
        address  string
5034
}
5035

5036
// batchLoadNodeData loads all related data for a batch of node IDs using the
5037
// provided SQLQueries interface. It returns a batchNodeData instance containing
5038
// the node features, addresses and extra signed fields.
5039
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
5040
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
5041

×
5042
        // Batch load the node features.
×
5043
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
5044
        if err != nil {
×
5045
                return nil, fmt.Errorf("unable to batch load node "+
×
5046
                        "features: %w", err)
×
5047
        }
×
5048

5049
        // Batch load the node addresses.
5050
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
5051
        if err != nil {
×
5052
                return nil, fmt.Errorf("unable to batch load node "+
×
5053
                        "addresses: %w", err)
×
5054
        }
×
5055

5056
        // Batch load the node extra signed fields.
5057
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
5058
        if err != nil {
×
5059
                return nil, fmt.Errorf("unable to batch load node extra "+
×
5060
                        "signed fields: %w", err)
×
5061
        }
×
5062

5063
        return &batchNodeData{
×
5064
                features:    features,
×
5065
                addresses:   addrs,
×
5066
                extraFields: extraTypes,
×
5067
        }, nil
×
5068
}
5069

5070
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5071
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5072
func batchLoadNodeFeaturesHelper(ctx context.Context,
5073
        cfg *sqldb.QueryConfig, db SQLQueries,
5074
        nodeIDs []int64) (map[int64][]int, error) {
×
5075

×
5076
        features := make(map[int64][]int)
×
5077

×
5078
        return features, sqldb.ExecuteBatchQuery(
×
5079
                ctx, cfg, nodeIDs,
×
5080
                func(id int64) int64 {
×
5081
                        return id
×
5082
                },
×
5083
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5084
                        error) {
×
5085

×
5086
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5087
                },
×
5088
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5089
                        features[feature.NodeID] = append(
×
5090
                                features[feature.NodeID],
×
5091
                                int(feature.FeatureBit),
×
5092
                        )
×
5093

×
5094
                        return nil
×
5095
                },
×
5096
        )
5097
}
5098

5099
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5100
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5101
// node ID to a slice of nodeAddress structs.
5102
func batchLoadNodeAddressesHelper(ctx context.Context,
5103
        cfg *sqldb.QueryConfig, db SQLQueries,
5104
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5105

×
5106
        addrs := make(map[int64][]nodeAddress)
×
5107

×
5108
        return addrs, sqldb.ExecuteBatchQuery(
×
5109
                ctx, cfg, nodeIDs,
×
5110
                func(id int64) int64 {
×
5111
                        return id
×
5112
                },
×
5113
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5114
                        error) {
×
5115

×
5116
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5117
                },
×
5118
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5119
                        addrs[addr.NodeID] = append(
×
5120
                                addrs[addr.NodeID], nodeAddress{
×
5121
                                        addrType: dbAddressType(addr.Type),
×
5122
                                        position: addr.Position,
×
5123
                                        address:  addr.Address,
×
5124
                                },
×
5125
                        )
×
5126

×
5127
                        return nil
×
5128
                },
×
5129
        )
5130
}
5131

5132
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5133
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5134
// query.
5135
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5136
        cfg *sqldb.QueryConfig, db SQLQueries,
5137
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5138

×
5139
        extraFields := make(map[int64]map[uint64][]byte)
×
5140

×
5141
        callback := func(ctx context.Context,
×
5142
                field sqlc.GraphNodeExtraType) error {
×
5143

×
5144
                if extraFields[field.NodeID] == nil {
×
5145
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5146
                }
×
5147
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5148

×
5149
                return nil
×
5150
        }
5151

5152
        return extraFields, sqldb.ExecuteBatchQuery(
×
5153
                ctx, cfg, nodeIDs,
×
5154
                func(id int64) int64 {
×
5155
                        return id
×
5156
                },
×
5157
                func(ctx context.Context, ids []int64) (
5158
                        []sqlc.GraphNodeExtraType, error) {
×
5159

×
5160
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5161
                },
×
5162
                callback,
5163
        )
5164
}
5165

5166
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5167
// from the provided sqlc.GraphChannelPolicy records and the
5168
// provided batchChannelData.
5169
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5170
        channelID uint64, node1, node2 route.Vertex,
5171
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5172
        *models.ChannelEdgePolicy, error) {
×
5173

×
5174
        pol1, err := buildChanPolicyWithBatchData(
×
5175
                dbPol1, channelID, node2, batchData,
×
5176
        )
×
5177
        if err != nil {
×
5178
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5179
        }
×
5180

5181
        pol2, err := buildChanPolicyWithBatchData(
×
5182
                dbPol2, channelID, node1, batchData,
×
5183
        )
×
5184
        if err != nil {
×
5185
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5186
        }
×
5187

5188
        return pol1, pol2, nil
×
5189
}
5190

5191
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5192
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5193
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5194
        channelID uint64, toNode route.Vertex,
5195
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
5196

×
5197
        if dbPol == nil {
×
5198
                return nil, nil
×
5199
        }
×
5200

5201
        var dbPol1Extras map[uint64][]byte
×
5202
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5203
                dbPol1Extras = extras
×
5204
        } else {
×
5205
                dbPol1Extras = make(map[uint64][]byte)
×
5206
        }
×
5207

5208
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5209
}
5210

5211
// batchChannelData holds all the related data for a batch of channels.
5212
type batchChannelData struct {
5213
        // chanFeatures is a map from DB channel ID to a slice of feature bits.
5214
        chanfeatures map[int64][]int
5215

5216
        // chanExtras is a map from DB channel ID to a map of TLV type to
5217
        // extra signed field bytes.
5218
        chanExtraTypes map[int64]map[uint64][]byte
5219

5220
        // policyExtras is a map from DB channel policy ID to a map of TLV type
5221
        // to extra signed field bytes.
5222
        policyExtras map[int64]map[uint64][]byte
5223
}
5224

5225
// batchLoadChannelData loads all related data for batches of channels and
5226
// policies.
5227
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5228
        db SQLQueries, channelIDs []int64,
5229
        policyIDs []int64) (*batchChannelData, error) {
×
5230

×
5231
        batchData := &batchChannelData{
×
5232
                chanfeatures:   make(map[int64][]int),
×
5233
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5234
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5235
        }
×
5236

×
5237
        // Batch load channel features and extras
×
5238
        var err error
×
5239
        if len(channelIDs) > 0 {
×
5240
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5241
                        ctx, cfg, db, channelIDs,
×
5242
                )
×
5243
                if err != nil {
×
5244
                        return nil, fmt.Errorf("unable to batch load "+
×
5245
                                "channel features: %w", err)
×
5246
                }
×
5247

5248
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5249
                        ctx, cfg, db, channelIDs,
×
5250
                )
×
5251
                if err != nil {
×
5252
                        return nil, fmt.Errorf("unable to batch load "+
×
5253
                                "channel extras: %w", err)
×
5254
                }
×
5255
        }
5256

5257
        if len(policyIDs) > 0 {
×
5258
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5259
                        ctx, cfg, db, policyIDs,
×
5260
                )
×
5261
                if err != nil {
×
5262
                        return nil, fmt.Errorf("unable to batch load "+
×
5263
                                "policy extras: %w", err)
×
5264
                }
×
5265
                batchData.policyExtras = policyExtras
×
5266
        }
5267

5268
        return batchData, nil
×
5269
}
5270

5271
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5272
// channel IDs using ExecuteBatchQuery wrapper around the
5273
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5274
// slice of feature bits.
5275
func batchLoadChannelFeaturesHelper(ctx context.Context,
5276
        cfg *sqldb.QueryConfig, db SQLQueries,
5277
        channelIDs []int64) (map[int64][]int, error) {
×
5278

×
5279
        features := make(map[int64][]int)
×
5280

×
5281
        return features, sqldb.ExecuteBatchQuery(
×
5282
                ctx, cfg, channelIDs,
×
5283
                func(id int64) int64 {
×
5284
                        return id
×
5285
                },
×
5286
                func(ctx context.Context,
5287
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5288

×
5289
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5290
                },
×
5291
                func(ctx context.Context,
5292
                        feature sqlc.GraphChannelFeature) error {
×
5293

×
5294
                        features[feature.ChannelID] = append(
×
5295
                                features[feature.ChannelID],
×
5296
                                int(feature.FeatureBit),
×
5297
                        )
×
5298

×
5299
                        return nil
×
5300
                },
×
5301
        )
5302
}
5303

5304
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5305
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5306
// query. It returns a map from DB channel ID to a map of TLV type to extra
5307
// signed field bytes.
5308
func batchLoadChannelExtrasHelper(ctx context.Context,
5309
        cfg *sqldb.QueryConfig, db SQLQueries,
5310
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5311

×
5312
        extras := make(map[int64]map[uint64][]byte)
×
5313

×
5314
        cb := func(ctx context.Context,
×
5315
                extra sqlc.GraphChannelExtraType) error {
×
5316

×
5317
                if extras[extra.ChannelID] == nil {
×
5318
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5319
                }
×
5320
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5321

×
5322
                return nil
×
5323
        }
5324

5325
        return extras, sqldb.ExecuteBatchQuery(
×
5326
                ctx, cfg, channelIDs,
×
5327
                func(id int64) int64 {
×
5328
                        return id
×
5329
                },
×
5330
                func(ctx context.Context,
5331
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5332

×
5333
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5334
                }, cb,
×
5335
        )
5336
}
5337

5338
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5339
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5340
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5341
// a map of TLV type to extra signed field bytes.
5342
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5343
        cfg *sqldb.QueryConfig, db SQLQueries,
5344
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5345

×
5346
        extras := make(map[int64]map[uint64][]byte)
×
5347

×
5348
        return extras, sqldb.ExecuteBatchQuery(
×
5349
                ctx, cfg, policyIDs,
×
5350
                func(id int64) int64 {
×
5351
                        return id
×
5352
                },
×
5353
                func(ctx context.Context, ids []int64) (
5354
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5355

×
5356
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5357
                },
×
5358
                func(ctx context.Context,
5359
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5360

×
5361
                        if extras[row.PolicyID] == nil {
×
5362
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5363
                        }
×
5364
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5365

×
5366
                        return nil
×
5367
                },
5368
        )
5369
}
5370

5371
// forEachNodePaginated executes a paginated query to process each node in the
5372
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5373
// and applies the provided processNode function to each node.
5374
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5375
        db SQLQueries, protocol lnwire.GossipVersion,
5376
        processNode func(context.Context, int64,
5377
                *models.Node) error) error {
×
5378

×
5379
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5380
                limit int32) ([]sqlc.GraphNode, error) {
×
5381

×
5382
                return db.ListNodesPaginated(
×
5383
                        ctx, sqlc.ListNodesPaginatedParams{
×
5384
                                Version: int16(protocol),
×
5385
                                ID:      lastID,
×
5386
                                Limit:   limit,
×
5387
                        },
×
5388
                )
×
5389
        }
×
5390

5391
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5392
                return node.ID
×
5393
        }
×
5394

5395
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5396
                return node.ID, nil
×
5397
        }
×
5398

5399
        batchQueryFunc := func(ctx context.Context,
×
5400
                nodeIDs []int64) (*batchNodeData, error) {
×
5401

×
5402
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5403
        }
×
5404

5405
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5406
                batchData *batchNodeData) error {
×
5407

×
5408
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5409
                if err != nil {
×
5410
                        return fmt.Errorf("unable to build "+
×
5411
                                "node(id=%d): %w", dbNode.ID, err)
×
5412
                }
×
5413

5414
                return processNode(ctx, dbNode.ID, node)
×
5415
        }
5416

5417
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5418
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5419
                collectFunc, batchQueryFunc, processItem,
×
5420
        )
×
5421
}
5422

5423
// forEachChannelWithPolicies executes a paginated query to process each channel
5424
// with policies in the graph.
5425
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5426
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5427
                *models.ChannelEdgePolicy,
5428
                *models.ChannelEdgePolicy) error) error {
×
5429

×
5430
        type channelBatchIDs struct {
×
5431
                channelID int64
×
5432
                policyIDs []int64
×
5433
        }
×
5434

×
5435
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5436
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5437
                error) {
×
5438

×
5439
                return db.ListChannelsWithPoliciesPaginated(
×
5440
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5441
                                Version: int16(lnwire.GossipVersion1),
×
5442
                                ID:      lastID,
×
5443
                                Limit:   limit,
×
5444
                        },
×
5445
                )
×
5446
        }
×
5447

5448
        extractPageCursor := func(
×
5449
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5450

×
5451
                return row.GraphChannel.ID
×
5452
        }
×
5453

5454
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5455
                channelBatchIDs, error) {
×
5456

×
5457
                ids := channelBatchIDs{
×
5458
                        channelID: row.GraphChannel.ID,
×
5459
                }
×
5460

×
5461
                // Extract policy IDs from the row.
×
5462
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5463
                if err != nil {
×
5464
                        return ids, err
×
5465
                }
×
5466

5467
                if dbPol1 != nil {
×
5468
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5469
                }
×
5470
                if dbPol2 != nil {
×
5471
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5472
                }
×
5473

5474
                return ids, nil
×
5475
        }
5476

5477
        batchDataFunc := func(ctx context.Context,
×
5478
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5479

×
5480
                // Separate channel IDs from policy IDs.
×
5481
                var (
×
5482
                        channelIDs = make([]int64, len(allIDs))
×
5483
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5484
                )
×
5485

×
5486
                for i, ids := range allIDs {
×
5487
                        channelIDs[i] = ids.channelID
×
5488
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5489
                }
×
5490

5491
                return batchLoadChannelData(
×
5492
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5493
                )
×
5494
        }
5495

5496
        processItem := func(ctx context.Context,
×
5497
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5498
                batchData *batchChannelData) error {
×
5499

×
5500
                node1, node2, err := buildNodeVertices(
×
5501
                        row.Node1Pubkey, row.Node2Pubkey,
×
5502
                )
×
5503
                if err != nil {
×
5504
                        return err
×
5505
                }
×
5506

5507
                edge, err := buildEdgeInfoWithBatchData(
×
5508
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5509
                        batchData,
×
5510
                )
×
5511
                if err != nil {
×
5512
                        return fmt.Errorf("unable to build channel info: %w",
×
5513
                                err)
×
5514
                }
×
5515

5516
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5517
                if err != nil {
×
5518
                        return err
×
5519
                }
×
5520

5521
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5522
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5523
                )
×
5524
                if err != nil {
×
5525
                        return err
×
5526
                }
×
5527

5528
                return processChannel(edge, p1, p2)
×
5529
        }
5530

5531
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5532
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5533
                collectFunc, batchDataFunc, processItem,
×
5534
        )
×
5535
}
5536

5537
// buildDirectedChannel builds a DirectedChannel instance from the provided
5538
// data.
5539
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5540
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5541
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5542
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5543

×
5544
        node1, node2, err := buildNodeVertices(
×
5545
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5546
        )
×
5547
        if err != nil {
×
5548
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5549
        }
×
5550

5551
        edge, err := buildEdgeInfoWithBatchData(
×
5552
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5553
        )
×
5554
        if err != nil {
×
5555
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5556
        }
×
5557

5558
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5559
        if err != nil {
×
5560
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5561
                        err)
×
5562
        }
×
5563

5564
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5565
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5566
                channelBatchData,
×
5567
        )
×
5568
        if err != nil {
×
5569
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5570
                        err)
×
5571
        }
×
5572

5573
        // Determine outgoing and incoming policy for this specific node.
5574
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5575
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5576
        outPolicy, inPolicy := p1, p2
×
5577
        if (p1 != nil && p1ToNode == nodeID) ||
×
5578
                (p2 != nil && p2ToNode != nodeID) {
×
5579

×
5580
                outPolicy, inPolicy = p2, p1
×
5581
        }
×
5582

5583
        // Build cached policy.
5584
        var cachedInPolicy *models.CachedEdgePolicy
×
5585
        if inPolicy != nil {
×
5586
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5587
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5588
                cachedInPolicy.ToNodeFeatures = features
×
5589
        }
×
5590

5591
        // Extract inbound fee.
5592
        var inboundFee lnwire.Fee
×
5593
        if outPolicy != nil {
×
5594
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5595
                        inboundFee = fee
×
5596
                })
×
5597
        }
5598

5599
        // Build directed channel.
5600
        directedChannel := &DirectedChannel{
×
5601
                ChannelID:    edge.ChannelID,
×
5602
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5603
                OtherNode:    edge.NodeKey2Bytes,
×
5604
                Capacity:     edge.Capacity,
×
5605
                OutPolicySet: outPolicy != nil,
×
5606
                InPolicy:     cachedInPolicy,
×
5607
                InboundFee:   inboundFee,
×
5608
        }
×
5609

×
5610
        if nodePub == edge.NodeKey2Bytes {
×
5611
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5612
        }
×
5613

5614
        return directedChannel, nil
×
5615
}
5616

5617
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5618
// provided rows. It uses batch loading for channels, policies, and nodes.
5619
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5620
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5621

×
5622
        var (
×
5623
                channelIDs = make([]int64, len(rows))
×
5624
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5625
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5626

×
5627
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5628
                nodeIDSet = make(map[int64]bool)
×
5629

×
5630
                // edges will hold the final channel edges built from the rows.
×
5631
                edges = make([]ChannelEdge, 0, len(rows))
×
5632
        )
×
5633

×
5634
        // Collect all IDs needed for batch loading.
×
5635
        for i, row := range rows {
×
5636
                channelIDs[i] = row.Channel().ID
×
5637

×
5638
                // Collect policy IDs
×
5639
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5640
                if err != nil {
×
5641
                        return nil, fmt.Errorf("unable to extract channel "+
×
5642
                                "policies: %w", err)
×
5643
                }
×
5644
                if dbPol1 != nil {
×
5645
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5646
                }
×
5647
                if dbPol2 != nil {
×
5648
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5649
                }
×
5650

5651
                var (
×
5652
                        node1ID = row.Node1().ID
×
5653
                        node2ID = row.Node2().ID
×
5654
                )
×
5655

×
5656
                // Collect unique node IDs.
×
5657
                if !nodeIDSet[node1ID] {
×
5658
                        nodeIDs = append(nodeIDs, node1ID)
×
5659
                        nodeIDSet[node1ID] = true
×
5660
                }
×
5661

5662
                if !nodeIDSet[node2ID] {
×
5663
                        nodeIDs = append(nodeIDs, node2ID)
×
5664
                        nodeIDSet[node2ID] = true
×
5665
                }
×
5666
        }
5667

5668
        // Batch the data for all the channels and policies.
5669
        channelBatchData, err := batchLoadChannelData(
×
5670
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5671
        )
×
5672
        if err != nil {
×
5673
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5674
                        "policy data: %w", err)
×
5675
        }
×
5676

5677
        // Batch the data for all the nodes.
5678
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5679
        if err != nil {
×
5680
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5681
                        err)
×
5682
        }
×
5683

5684
        // Build all channel edges using batch data.
5685
        for _, row := range rows {
×
5686
                // Build nodes using batch data.
×
5687
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5688
                if err != nil {
×
5689
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5690
                }
×
5691

5692
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5693
                if err != nil {
×
5694
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5695
                }
×
5696

5697
                // Build channel info using batch data.
5698
                channel, err := buildEdgeInfoWithBatchData(
×
5699
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5700
                        node2.PubKeyBytes, channelBatchData,
×
5701
                )
×
5702
                if err != nil {
×
5703
                        return nil, fmt.Errorf("unable to build channel "+
×
5704
                                "info: %w", err)
×
5705
                }
×
5706

5707
                // Extract and build policies using batch data.
5708
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5709
                if err != nil {
×
5710
                        return nil, fmt.Errorf("unable to extract channel "+
×
5711
                                "policies: %w", err)
×
5712
                }
×
5713

5714
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5715
                        dbPol1, dbPol2, channel.ChannelID,
×
5716
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5717
                )
×
5718
                if err != nil {
×
5719
                        return nil, fmt.Errorf("unable to build channel "+
×
5720
                                "policies: %w", err)
×
5721
                }
×
5722

5723
                edges = append(edges, ChannelEdge{
×
5724
                        Info:    channel,
×
5725
                        Policy1: p1,
×
5726
                        Policy2: p2,
×
5727
                        Node1:   node1,
×
5728
                        Node2:   node2,
×
5729
                })
×
5730
        }
5731

5732
        return edges, nil
×
5733
}
5734

5735
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5736
// instances from the provided rows using batch loading for channel data.
5737
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5738
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5739
        []*models.ChannelEdgeInfo, []int64, error) {
×
5740

×
5741
        if len(rows) == 0 {
×
5742
                return nil, nil, nil
×
5743
        }
×
5744

5745
        // Collect all the channel IDs needed for batch loading.
5746
        channelIDs := make([]int64, len(rows))
×
5747
        for i, row := range rows {
×
5748
                channelIDs[i] = row.Channel().ID
×
5749
        }
×
5750

5751
        // Batch load the channel data.
5752
        channelBatchData, err := batchLoadChannelData(
×
5753
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5754
        )
×
5755
        if err != nil {
×
5756
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5757
                        "data: %w", err)
×
5758
        }
×
5759

5760
        // Build all channel edges using batch data.
5761
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5762
        for _, row := range rows {
×
5763
                node1, node2, err := buildNodeVertices(
×
5764
                        row.Node1Pub(), row.Node2Pub(),
×
5765
                )
×
5766
                if err != nil {
×
5767
                        return nil, nil, err
×
5768
                }
×
5769

5770
                // Build channel info using batch data
5771
                info, err := buildEdgeInfoWithBatchData(
×
5772
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5773
                        channelBatchData,
×
5774
                )
×
5775
                if err != nil {
×
5776
                        return nil, nil, err
×
5777
                }
×
5778

5779
                edges = append(edges, info)
×
5780
        }
5781

5782
        return edges, channelIDs, nil
×
5783
}
5784

5785
// handleZombieMarking is a helper function that handles the logic of
5786
// marking a channel as a zombie in the database. It takes into account whether
5787
// we are in strict zombie pruning mode, and adjusts the node public keys
5788
// accordingly based on the last update timestamps of the channel policies.
5789
func handleZombieMarking(ctx context.Context, db SQLQueries,
5790
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5791
        strictZombiePruning bool, scid uint64) error {
×
5792

×
5793
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5794

×
5795
        if strictZombiePruning {
×
5796
                var e1UpdateTime, e2UpdateTime *time.Time
×
5797
                if row.Policy1LastUpdate.Valid {
×
5798
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5799
                        e1UpdateTime = &e1Time
×
5800
                }
×
5801
                if row.Policy2LastUpdate.Valid {
×
5802
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5803
                        e2UpdateTime = &e2Time
×
5804
                }
×
5805

5806
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5807
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5808
                        e2UpdateTime,
×
5809
                )
×
5810
        }
5811

5812
        return db.UpsertZombieChannel(
×
5813
                ctx, sqlc.UpsertZombieChannelParams{
×
5814
                        Version:  int16(lnwire.GossipVersion1),
×
5815
                        Scid:     channelIDToBytes(scid),
×
5816
                        NodeKey1: nodeKey1[:],
×
5817
                        NodeKey2: nodeKey2[:],
×
5818
                },
×
5819
        )
×
5820
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc