
lightningnetwork / lnd, build 20064612575

09 Dec 2025 01:09PM UTC coverage: 65.208% (+0.002%) from 65.206%

push via github (web-flow)
Merge pull request #10428 from ziggie1984/fix-sql-pool-exhaustion

graphdb: fix potential sql tx exhaustion

7 of 14 new or added lines in 2 files covered. (50.0%)

91 existing lines in 20 files now uncovered.

137718 of 211199 relevant lines covered (65.21%)

20721.63 hits per line

Source File: /graph/db/sql_store.go (0.0% covered)
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
37
// execute queries against the SQL graph tables.
38
//
39
//nolint:ll,interfacebloat
40
type SQLQueries interface {
41
        /*
42
                Node queries.
43
        */
44
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
45
        UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
46
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
47
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
48
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
49
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
50
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
51
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
52
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
53
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
54
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
55
        DeleteNode(ctx context.Context, id int64) error
56

57
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
58
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
59
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
60
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
61

62
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
63
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
64
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
65
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
66

67
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
68
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
69
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
70
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
71
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
72

73
        /*
74
                Source node queries.
75
        */
76
        AddSourceNode(ctx context.Context, nodeID int64) error
77
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
78

79
        /*
80
                Channel queries.
81
        */
82
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
83
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
84
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
85
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
86
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
87
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
88
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
89
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
90
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
91
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
92
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
93
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
94
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
95
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
96
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
97
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
98
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
99
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
100
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
101
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
102
        DeleteChannels(ctx context.Context, ids []int64) error
103

104
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
105
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
106
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
107
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
108

109
        /*
110
                Channel Policy table queries.
111
        */
112
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
113
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
114
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
115

116
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
117
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
118
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
119

120
        /*
121
                Zombie index queries.
122
        */
123
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
124
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
125
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
126
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
127
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
128
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
129

130
        /*
131
                Prune log table queries.
132
        */
133
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
134
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
135
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
136
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
137
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
138

139
        /*
140
                Closed SCID table queries.
141
        */
142
        InsertClosedChannel(ctx context.Context, scid []byte) error
143
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
144
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
145

146
        /*
147
                Migration specific queries.
148

149
                NOTE: these should not be used in code other than migrations.
150
                Once sqldbv2 is in place, these can be removed from this struct
151
                as then migrations will have their own dedicated queries
152
                structs.
153
        */
154
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
155
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
156
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
157
}
158

159
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
160
// database operations.
161
type BatchedSQLQueries interface {
162
        SQLQueries
163
        sqldb.BatchedTx[SQLQueries]
164
}
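// Example (illustrative sketch, not part of the actual file): every exported
// method below wraps its SQLQueries calls in a single transaction via the
// ExecTx method that BatchedSQLQueries provides. A hypothetical read-only
// helper following the same pattern (countV1Zombies is an invented name; it
// reuses this file's imports) might look like:
func countV1Zombies(ctx context.Context, db BatchedSQLQueries) (int64, error) {
        var count int64
        err := db.ExecTx(ctx, sqldb.ReadTxOpt(), func(q SQLQueries) error {
                var err error
                count, err = q.CountZombieChannels(
                        ctx, int16(lnwire.GossipVersion1),
                )

                return err
        }, sqldb.NoOpReset)
        if err != nil {
                return 0, fmt.Errorf("unable to count zombie channels: %w", err)
        }

        return count, nil
}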
165

166
// SQLStore is an implementation of the V1Store interface that uses a SQL
167
// database as the backend.
168
type SQLStore struct {
169
        cfg *SQLStoreConfig
170
        db  BatchedSQLQueries
171

172
        // cacheMu guards all caches (rejectCache and chanCache). If
173
        // this mutex is to be acquired at the same time as the DB mutex, then
174
        // the cacheMu MUST be acquired first to prevent deadlock.
175
        cacheMu     sync.RWMutex
176
        rejectCache *rejectCache
177
        chanCache   *channelCache
178

179
        chanScheduler batch.Scheduler[SQLQueries]
180
        nodeScheduler batch.Scheduler[SQLQueries]
181

182
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
183
        srcNodeMu sync.Mutex
184
}
185

186
// A compile-time assertion to ensure that SQLStore implements the V1Store
187
// interface.
188
var _ V1Store = (*SQLStore)(nil)
189

190
// SQLStoreConfig holds the configuration for the SQLStore.
191
type SQLStoreConfig struct {
192
        // ChainHash is the genesis hash for the chain that all the gossip
193
        // messages in this store are aimed at.
194
        ChainHash chainhash.Hash
195

196
        // QueryConfig holds configuration values for SQL queries.
197
        QueryCfg *sqldb.QueryConfig
198
}
199

200
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
201
// storage backend.
202
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
203
        options ...StoreOptionModifier) (*SQLStore, error) {
×
204

×
205
        opts := DefaultOptions()
×
206
        for _, o := range options {
×
207
                o(opts)
×
208
        }
×
209

210
        if opts.NoMigration {
×
211
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
212
                        "supported for SQL stores")
×
213
        }
×
214

215
        s := &SQLStore{
×
216
                cfg:         cfg,
×
217
                db:          db,
×
218
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
219
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
220
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
221
        }
×
222

×
223
        s.chanScheduler = batch.NewTimeScheduler(
×
224
                db, &s.cacheMu, opts.BatchCommitInterval,
×
225
        )
×
226
        s.nodeScheduler = batch.NewTimeScheduler(
×
227
                db, nil, opts.BatchCommitInterval,
×
228
        )
×
229

×
230
        return s, nil
×
231
}
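// Example (illustrative sketch, not part of the actual file): wiring up a
// store. It assumes the caller already has a BatchedSQLQueries backend and a
// query config from the sqldb package; newStoreForChain is an invented name.
func newStoreForChain(db BatchedSQLQueries, chain chainhash.Hash,
        queryCfg *sqldb.QueryConfig) (*SQLStore, error) {

        cfg := &SQLStoreConfig{
                ChainHash: chain,
                QueryCfg:  queryCfg,
        }

        return NewSQLStore(cfg, db)
}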
232

233
// AddNode adds a vertex/node to the graph database. If the node is not
234
// in the database from before, this will add a new, unconnected one to the
235
// graph. If it is present from before, this will update that node's
236
// information.
237
//
238
// NOTE: part of the V1Store interface.
239
func (s *SQLStore) AddNode(ctx context.Context,
240
        node *models.Node, opts ...batch.SchedulerOption) error {
×
241

×
242
        r := &batch.Request[SQLQueries]{
×
243
                Opts: batch.NewSchedulerOptions(opts...),
×
244
                Do: func(queries SQLQueries) error {
×
245
                        _, err := upsertNode(ctx, queries, node)
×
246

×
247
                        // It is possible that two of the same node
×
248
                        // announcements are both being processed in the same
×
249
                        // batch. This may cause the UpsertNode conflict to
×
250
                        // be hit since we require at the db layer that the
×
251
                        // new last_update is greater than the existing
×
252
                        // last_update. We need to gracefully handle this here.
×
253
                        if errors.Is(err, sql.ErrNoRows) {
×
254
                                return nil
×
255
                        }
×
256

257
                        return err
×
258
                },
259
        }
260

261
        return s.nodeScheduler.Execute(ctx, r)
×
262
}
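// Example (illustrative sketch, not part of the actual file): persisting a
// node announcement. Only the models.Node fields referenced elsewhere in this
// file are populated; a real announcement would also carry the alias,
// addresses and feature vector. addNodeSketch is an invented name.
func addNodeSketch(ctx context.Context, s *SQLStore,
        pubKey route.Vertex) error {

        node := &models.Node{
                PubKeyBytes: pubKey,
                LastUpdate:  time.Now(),
        }

        return s.AddNode(ctx, node)
}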
263

264
// FetchNode attempts to look up a target node by its identity public
265
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
266
// returned.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) FetchNode(ctx context.Context,
270
        pubKey route.Vertex) (*models.Node, error) {
×
271

×
272
        var node *models.Node
×
273
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
274
                var err error
×
275
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
276

×
277
                return err
×
278
        }, sqldb.NoOpReset)
×
279
        if err != nil {
×
280
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
281
        }
×
282

283
        return node, nil
×
284
}
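// Example (illustrative sketch, not part of the actual file): looking up a
// node while treating "unknown node" separately from real failures.
// fetchNodeSketch is an invented name.
func fetchNodeSketch(ctx context.Context, s *SQLStore,
        pubKey route.Vertex) (*models.Node, error) {

        node, err := s.FetchNode(ctx, pubKey)
        switch {
        // The node has simply not been announced to us yet.
        case errors.Is(err, ErrGraphNodeNotFound):
                return nil, nil

        case err != nil:
                return nil, err
        }

        return node, nil
}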
285

286
// HasNode determines if the graph has a vertex identified by the
287
// target node identity public key. If the node exists in the database, a
288
// timestamp of when the data for the node was last updated is returned along
289
// with a true boolean. Otherwise, an empty time.Time is returned with a false
290
// boolean.
291
//
292
// NOTE: part of the V1Store interface.
293
func (s *SQLStore) HasNode(ctx context.Context,
294
        pubKey [33]byte) (time.Time, bool, error) {
×
295

×
296
        var (
×
297
                exists     bool
×
298
                lastUpdate time.Time
×
299
        )
×
300
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
301
                dbNode, err := db.GetNodeByPubKey(
×
302
                        ctx, sqlc.GetNodeByPubKeyParams{
×
303
                                Version: int16(lnwire.GossipVersion1),
×
304
                                PubKey:  pubKey[:],
×
305
                        },
×
306
                )
×
307
                if errors.Is(err, sql.ErrNoRows) {
×
308
                        return nil
×
309
                } else if err != nil {
×
310
                        return fmt.Errorf("unable to fetch node: %w", err)
×
311
                }
×
312

313
                exists = true
×
314

×
315
                if dbNode.LastUpdate.Valid {
×
316
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
317
                }
×
318

319
                return nil
×
320
        }, sqldb.NoOpReset)
321
        if err != nil {
×
322
                return time.Time{}, false,
×
323
                        fmt.Errorf("unable to fetch node: %w", err)
×
324
        }
×
325

326
        return lastUpdate, exists, nil
×
327
}
328

329
// AddrsForNode returns all known addresses for the target node public key
330
// that the graph DB is aware of. The returned boolean indicates if the
331
// given node is unknown to the graph DB or not.
332
//
333
// NOTE: part of the V1Store interface.
334
func (s *SQLStore) AddrsForNode(ctx context.Context,
335
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
336

×
337
        var (
×
338
                addresses []net.Addr
×
339
                known     bool
×
340
        )
×
341
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
342
                // First, check if the node exists and get its DB ID if it
×
343
                // does.
×
344
                dbID, err := db.GetNodeIDByPubKey(
×
345
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
346
                                Version: int16(lnwire.GossipVersion1),
×
347
                                PubKey:  nodePub.SerializeCompressed(),
×
348
                        },
×
349
                )
×
350
                if errors.Is(err, sql.ErrNoRows) {
×
351
                        return nil
×
352
                }
×
353

354
                known = true
×
355

×
356
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
357
                if err != nil {
×
358
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
359
                                err)
×
360
                }
×
361

362
                return nil
×
363
        }, sqldb.NoOpReset)
364
        if err != nil {
×
365
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
366
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
367
        }
×
368

369
        return known, addresses, nil
×
370
}
371

372
// DeleteNode starts a new database transaction to remove a vertex/node
373
// from the database according to the node's public key.
374
//
375
// NOTE: part of the V1Store interface.
376
func (s *SQLStore) DeleteNode(ctx context.Context,
377
        pubKey route.Vertex) error {
×
378

×
379
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
380
                res, err := db.DeleteNodeByPubKey(
×
381
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
382
                                Version: int16(lnwire.GossipVersion1),
×
383
                                PubKey:  pubKey[:],
×
384
                        },
×
385
                )
×
386
                if err != nil {
×
387
                        return err
×
388
                }
×
389

390
                rows, err := res.RowsAffected()
×
391
                if err != nil {
×
392
                        return err
×
393
                }
×
394

395
                if rows == 0 {
×
396
                        return ErrGraphNodeNotFound
×
397
                } else if rows > 1 {
×
398
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
399
                }
×
400

401
                return err
×
402
        }, sqldb.NoOpReset)
403
        if err != nil {
×
404
                return fmt.Errorf("unable to delete node: %w", err)
×
405
        }
×
406

407
        return nil
×
408
}
409

410
// FetchNodeFeatures returns the features of the given node. If no features are
411
// known for the node, an empty feature vector is returned.
412
//
413
// NOTE: this is part of the graphdb.NodeTraverser interface.
414
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
415
        *lnwire.FeatureVector, error) {
×
416

×
417
        ctx := context.TODO()
×
418

×
419
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
420
}
×
421

422
// DisabledChannelIDs returns the channel ids of disabled channels.
423
// A channel is disabled when both of the associated ChannelEdgePolicies
424
// have their disabled bit on.
425
//
426
// NOTE: part of the V1Store interface.
427
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
428
        var (
×
429
                ctx     = context.TODO()
×
430
                chanIDs []uint64
×
431
        )
×
432
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
433
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
434
                if err != nil {
×
435
                        return fmt.Errorf("unable to fetch disabled "+
×
436
                                "channels: %w", err)
×
437
                }
×
438

439
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
440

×
441
                return nil
×
442
        }, sqldb.NoOpReset)
443
        if err != nil {
×
444
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
445
                        err)
×
446
        }
×
447

448
        return chanIDs, nil
×
449
}
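// Example (illustrative sketch, not part of the actual file): the returned
// IDs are raw uint64 short channel IDs, so they can be decoded with
// lnwire.NewShortChanIDFromInt for logging. logDisabledChannels is an
// invented name.
func logDisabledChannels(s *SQLStore) error {
        chanIDs, err := s.DisabledChannelIDs()
        if err != nil {
                return err
        }

        for _, chanID := range chanIDs {
                scid := lnwire.NewShortChanIDFromInt(chanID)
                log.Debugf("Channel %v is disabled in both directions", scid)
        }

        return nil
}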
450

451
// LookupAlias attempts to return the alias as advertised by the target node.
452
//
453
// NOTE: part of the V1Store interface.
454
func (s *SQLStore) LookupAlias(ctx context.Context,
455
        pub *btcec.PublicKey) (string, error) {
×
456

×
457
        var alias string
×
458
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
459
                dbNode, err := db.GetNodeByPubKey(
×
460
                        ctx, sqlc.GetNodeByPubKeyParams{
×
461
                                Version: int16(lnwire.GossipVersion1),
×
462
                                PubKey:  pub.SerializeCompressed(),
×
463
                        },
×
464
                )
×
465
                if errors.Is(err, sql.ErrNoRows) {
×
466
                        return ErrNodeAliasNotFound
×
467
                } else if err != nil {
×
468
                        return fmt.Errorf("unable to fetch node: %w", err)
×
469
                }
×
470

471
                if !dbNode.Alias.Valid {
×
472
                        return ErrNodeAliasNotFound
×
473
                }
×
474

475
                alias = dbNode.Alias.String
×
476

×
477
                return nil
×
478
        }, sqldb.NoOpReset)
479
        if err != nil {
×
480
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
481
        }
×
482

483
        return alias, nil
×
484
}
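// Example (illustrative sketch, not part of the actual file): resolving an
// alias with a fallback for nodes that never announced one. aliasOrDefault is
// an invented name.
func aliasOrDefault(ctx context.Context, s *SQLStore,
        pub *btcec.PublicKey) (string, error) {

        alias, err := s.LookupAlias(ctx, pub)
        if errors.Is(err, ErrNodeAliasNotFound) {
                return "", nil
        } else if err != nil {
                return "", err
        }

        return alias, nil
}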
485

486
// SourceNode returns the source node of the graph. The source node is treated
487
// as the center node within a star-graph. This method may be used to kick off
488
// a path finding algorithm in order to explore the reachability of another
489
// node based off the source node.
490
//
491
// NOTE: part of the V1Store interface.
492
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
493
        error) {
×
494

×
495
        var node *models.Node
×
496
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
497
                _, nodePub, err := s.getSourceNode(
×
498
                        ctx, db, lnwire.GossipVersion1,
×
499
                )
×
500
                if err != nil {
×
501
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
502
                                err)
×
503
                }
×
504

505
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
506

×
507
                return err
×
508
        }, sqldb.NoOpReset)
509
        if err != nil {
×
510
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
511
        }
×
512

513
        return node, nil
×
514
}
515

516
// SetSourceNode sets the source node within the graph database. The source
517
// node is to be used as the center of a star-graph within path finding
518
// algorithms.
519
//
520
// NOTE: part of the V1Store interface.
521
func (s *SQLStore) SetSourceNode(ctx context.Context,
522
        node *models.Node) error {
×
523

×
524
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
525
                // For the source node, we use a less strict upsert that allows
×
526
                // updates even when the timestamp hasn't changed. This handles
×
527
                // the race condition where multiple goroutines (e.g.,
×
528
                // setSelfNode, createNewHiddenService, RPC updates) read the
×
529
                // same old timestamp, independently increment it, and try to
×
530
                // write concurrently. We want all parameter changes to persist,
×
531
                // even if timestamps collide.
×
532
                id, err := upsertSourceNode(ctx, db, node)
×
533
                if err != nil {
×
534
                        return fmt.Errorf("unable to upsert source node: %w",
×
535
                                err)
×
536
                }
×
537

538
                // Make sure that if a source node for this version is already
539
                // set, then the ID is the same as the one we are about to set.
540
                dbSourceNodeID, _, err := s.getSourceNode(
×
541
                        ctx, db, lnwire.GossipVersion1,
×
542
                )
×
543
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
544
                        return fmt.Errorf("unable to fetch source node: %w",
×
545
                                err)
×
546
                } else if err == nil {
×
547
                        if dbSourceNodeID != id {
×
548
                                return fmt.Errorf("v1 source node already "+
×
549
                                        "set to a different node: %d vs %d",
×
550
                                        dbSourceNodeID, id)
×
551
                        }
×
552

553
                        return nil
×
554
                }
555

556
                return db.AddSourceNode(ctx, id)
×
557
        }, sqldb.NoOpReset)
558
}
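// Example (illustrative sketch, not part of the actual file): registering our
// own node as the graph's source node and reading it back. setOwnNode is an
// invented name; as in the AddNode sketch above, only a minimal set of
// models.Node fields is populated.
func setOwnNode(ctx context.Context, s *SQLStore, pubKey route.Vertex) error {
        node := &models.Node{
                PubKeyBytes: pubKey,
                LastUpdate:  time.Now(),
        }
        if err := s.SetSourceNode(ctx, node); err != nil {
                return err
        }

        // The node is now returned as the center of the star-graph used by
        // path finding.
        _, err := s.SourceNode(ctx)

        return err
}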
559

560
// NodeUpdatesInHorizon returns all the known lightning nodes which have an
561
// update timestamp within the passed range. This method can be used by two
562
// nodes to quickly determine if they have the same set of up to date node
563
// announcements.
564
//
565
// NOTE: This is part of the V1Store interface.
566
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
567
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
568

×
569
        cfg := defaultIteratorConfig()
×
570
        for _, opt := range opts {
×
571
                opt(cfg)
×
572
        }
×
573

574
        return func(yield func(*models.Node, error) bool) {
×
575
                var (
×
576
                        ctx            = context.TODO()
×
577
                        lastUpdateTime sql.NullInt64
×
578
                        lastPubKey     = make([]byte, 33)
×
579
                        hasMore        = true
×
580
                )
×
581

×
582
                // Each iteration, we'll read a batch amount of nodes, yield
×
583
                // them, then decide if we have more or not.
×
584
                for hasMore {
×
585
                        var batch []*models.Node
×
586

×
587
                        //nolint:ll
×
588
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
589
                                //nolint:ll
×
590
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
591
                                        StartTime: sqldb.SQLInt64(
×
592
                                                startTime.Unix(),
×
593
                                        ),
×
594
                                        EndTime: sqldb.SQLInt64(
×
595
                                                endTime.Unix(),
×
596
                                        ),
×
597
                                        LastUpdate: lastUpdateTime,
×
598
                                        LastPubKey: lastPubKey,
×
599
                                        OnlyPublic: sql.NullBool{
×
600
                                                Bool:  cfg.iterPublicNodes,
×
601
                                                Valid: true,
×
602
                                        },
×
603
                                        MaxResults: sqldb.SQLInt32(
×
604
                                                cfg.nodeUpdateIterBatchSize,
×
605
                                        ),
×
606
                                }
×
607
                                rows, err := db.GetNodesByLastUpdateRange(
×
608
                                        ctx, params,
×
609
                                )
×
610
                                if err != nil {
×
611
                                        return err
×
612
                                }
×
613

614
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
615

×
616
                                err = forEachNodeInBatch(
×
617
                                        ctx, s.cfg.QueryCfg, db, rows,
×
618
                                        func(_ int64, node *models.Node) error {
×
619
                                                batch = append(batch, node)
×
620

×
621
                                                // Update pagination cursors
×
622
                                                // based on the last processed
×
623
                                                // node.
×
624
                                                lastUpdateTime = sql.NullInt64{
×
625
                                                        Int64: node.LastUpdate.
×
626
                                                                Unix(),
×
627
                                                        Valid: true,
×
628
                                                }
×
629
                                                lastPubKey = node.PubKeyBytes[:]
×
630

×
631
                                                return nil
×
632
                                        },
×
633
                                )
634
                                if err != nil {
×
635
                                        return fmt.Errorf("unable to build "+
×
636
                                                "nodes: %w", err)
×
637
                                }
×
638

639
                                return nil
×
640
                        }, func() {
×
641
                                batch = []*models.Node{}
×
642
                        })
×
643

644
                        if err != nil {
×
645
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
646
                                        "error: %v", err)
×
647

×
648
                                yield(&models.Node{}, err)
×
649

×
650
                                return
×
651
                        }
×
652

653
                        for _, node := range batch {
×
654
                                if !yield(node, nil) {
×
655
                                        return
×
656
                                }
×
657
                        }
658

659
                        // If the batch didn't yield anything, then we're done.
660
                        if len(batch) == 0 {
×
661
                                break
×
662
                        }
663
                }
664
        }
665
}
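// Example (illustrative sketch, not part of the actual file): the returned
// iter.Seq2 can be consumed with a range-over-function loop, and errors are
// yielded in-band. collectRecentNodes is an invented name.
func collectRecentNodes(s *SQLStore, start, end time.Time) ([]*models.Node,
        error) {

        var nodes []*models.Node
        for node, err := range s.NodeUpdatesInHorizon(start, end) {
                if err != nil {
                        return nil, err
                }

                nodes = append(nodes, node)
        }

        return nodes, nil
}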
666

667
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
668
// undirected edge between the two target nodes is created. The information stored
669
// denotes the static attributes of the channel, such as the channelID, the keys
670
// involved in creation of the channel, and the set of features that the channel
671
// supports. The chanPoint and chanID are used to uniquely identify the edge
672
// globally within the database.
673
//
674
// NOTE: part of the V1Store interface.
675
func (s *SQLStore) AddChannelEdge(ctx context.Context,
676
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
677

×
678
        var alreadyExists bool
×
679
        r := &batch.Request[SQLQueries]{
×
680
                Opts: batch.NewSchedulerOptions(opts...),
×
681
                Reset: func() {
×
682
                        alreadyExists = false
×
683
                },
×
684
                Do: func(tx SQLQueries) error {
×
685
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
686

×
687
                        // Make sure that the channel doesn't already exist. We
×
688
                        // do this explicitly instead of relying on catching a
×
689
                        // unique constraint error because relying on SQL to
×
690
                        // throw that error would abort the entire batch of
×
691
                        // transactions.
×
692
                        _, err := tx.GetChannelBySCID(
×
693
                                ctx, sqlc.GetChannelBySCIDParams{
×
694
                                        Scid:    chanIDB,
×
695
                                        Version: int16(lnwire.GossipVersion1),
×
696
                                },
×
697
                        )
×
698
                        if err == nil {
×
699
                                alreadyExists = true
×
700
                                return nil
×
701
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
702
                                return fmt.Errorf("unable to fetch channel: %w",
×
703
                                        err)
×
704
                        }
×
705

706
                        return insertChannel(ctx, tx, edge)
×
707
                },
708
                OnCommit: func(err error) error {
×
709
                        switch {
×
710
                        case err != nil:
×
711
                                return err
×
712
                        case alreadyExists:
×
713
                                return ErrEdgeAlreadyExist
×
714
                        default:
×
715
                                s.rejectCache.remove(edge.ChannelID)
×
716
                                s.chanCache.remove(edge.ChannelID)
×
717
                                return nil
×
718
                        }
719
                },
720
        }
721

722
        return s.chanScheduler.Execute(ctx, r)
×
723
}
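// Example (illustrative sketch, not part of the actual file): because the
// insert is batched, a duplicate announcement surfaces as ErrEdgeAlreadyExist
// from the scheduler rather than as a SQL constraint error. addEdgeIfNew is
// an invented name; edge construction is left to the caller.
func addEdgeIfNew(ctx context.Context, s *SQLStore,
        edge *models.ChannelEdgeInfo) error {

        err := s.AddChannelEdge(ctx, edge)
        if errors.Is(err, ErrEdgeAlreadyExist) {
                // Nothing to do: the channel is already known.
                return nil
        }

        return err
}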
724

725
// HighestChanID returns the "highest" known channel ID in the channel graph.
726
// This represents the "newest" channel from the PoV of the chain. This method
727
// can be used by peers to quickly determine if their graphs are in sync.
728
//
729
// NOTE: This is part of the V1Store interface.
730
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
731
        var highestChanID uint64
×
732
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
733
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
734
                if errors.Is(err, sql.ErrNoRows) {
×
735
                        return nil
×
736
                } else if err != nil {
×
737
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
738
                                err)
×
739
                }
×
740

741
                highestChanID = byteOrder.Uint64(chanID)
×
742

×
743
                return nil
×
744
        }, sqldb.NoOpReset)
745
        if err != nil {
×
746
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
747
        }
×
748

749
        return highestChanID, nil
×
750
}
751

752
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
753
// within the database for the referenced channel. The `flags` attribute within
754
// the ChannelEdgePolicy determines which of the directed edges are being
755
// updated. If the flag is 1, then the first node's information is being
756
// updated, otherwise it's the second node's information. The node ordering is
757
// determined by the lexicographical ordering of the identity public keys of the
758
// nodes on either side of the channel.
759
//
760
// NOTE: part of the V1Store interface.
761
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
762
        edge *models.ChannelEdgePolicy,
763
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
764

×
765
        var (
×
766
                isUpdate1    bool
×
767
                edgeNotFound bool
×
768
                from, to     route.Vertex
×
769
        )
×
770

×
771
        r := &batch.Request[SQLQueries]{
×
772
                Opts: batch.NewSchedulerOptions(opts...),
×
773
                Reset: func() {
×
774
                        isUpdate1 = false
×
775
                        edgeNotFound = false
×
776
                },
×
777
                Do: func(tx SQLQueries) error {
×
778
                        var err error
×
779
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
780
                                ctx, tx, edge,
×
781
                        )
×
782
                        // It is possible that two of the same policy
×
783
                        // announcements are both being processed in the same
×
784
                        // batch. This may cause the UpsertEdgePolicy conflict to
×
785
                        // be hit since we require at the db layer that the
×
786
                        // new last_update is greater than the existing
×
787
                        // last_update. We need to gracefully handle this here.
×
788
                        if errors.Is(err, sql.ErrNoRows) {
×
789
                                return nil
×
790
                        } else if err != nil {
×
791
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
792
                        }
×
793

794
                        // Silence ErrEdgeNotFound so that the batch can
795
                        // succeed, but propagate the error via local state.
796
                        if errors.Is(err, ErrEdgeNotFound) {
×
797
                                edgeNotFound = true
×
798
                                return nil
×
799
                        }
×
800

801
                        return err
×
802
                },
803
                OnCommit: func(err error) error {
×
804
                        switch {
×
805
                        case err != nil:
×
806
                                return err
×
807
                        case edgeNotFound:
×
808
                                return ErrEdgeNotFound
×
809
                        default:
×
810
                                s.updateEdgeCache(edge, isUpdate1)
×
811
                                return nil
×
812
                        }
813
                },
814
        }
815

816
        err := s.chanScheduler.Execute(ctx, r)
×
817

×
818
        return from, to, err
×
819
}
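// Example (illustrative sketch, not part of the actual file): applying a
// policy update and using the returned vertices to identify the channel's
// endpoints. applyPolicyUpdate is an invented name; policy construction is
// left to the caller.
func applyPolicyUpdate(ctx context.Context, s *SQLStore,
        policy *models.ChannelEdgePolicy) error {

        from, to, err := s.UpdateEdgePolicy(ctx, policy)
        if errors.Is(err, ErrEdgeNotFound) {
                // The channel announcement has not been processed yet, so the
                // policy cannot be attached.
                return err
        } else if err != nil {
                return err
        }

        log.Debugf("Updated policy for channel %v between %x and %x",
                policy.ChannelID, from[:], to[:])

        return nil
}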
820

821
// updateEdgeCache updates our reject and channel caches with the new
822
// edge policy information.
823
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
824
        isUpdate1 bool) {
×
825

×
826
        // If an entry for this channel is found in reject cache, we'll modify
×
827
        // the entry with the updated timestamp for the direction that was just
×
828
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
829
        // during the next query for this edge.
×
830
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
831
                if isUpdate1 {
×
832
                        entry.upd1Time = e.LastUpdate.Unix()
×
833
                } else {
×
834
                        entry.upd2Time = e.LastUpdate.Unix()
×
835
                }
×
836
                s.rejectCache.insert(e.ChannelID, entry)
×
837
        }
838

839
        // If an entry for this channel is found in channel cache, we'll modify
840
        // the entry with the updated policy for the direction that was just
841
        // written. If the edge doesn't exist, we'll defer loading the info and
842
        // policies and lazily read from disk during the next query.
843
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
844
                if isUpdate1 {
×
845
                        channel.Policy1 = e
×
846
                } else {
×
847
                        channel.Policy2 = e
×
848
                }
×
849
                s.chanCache.insert(e.ChannelID, channel)
×
850
        }
851
}
852

853
// ForEachSourceNodeChannel iterates through all channels of the source node,
854
// executing the passed callback on each. The callback is provided with the
855
// channel's outpoint, whether we have a policy for the channel and the channel
856
// peer's node information.
857
//
858
// NOTE: part of the V1Store interface.
859
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
860
        cb func(chanPoint wire.OutPoint, havePolicy bool,
861
                otherNode *models.Node) error, reset func()) error {
×
862

×
863
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
864
                nodeID, nodePub, err := s.getSourceNode(
×
865
                        ctx, db, lnwire.GossipVersion1,
×
866
                )
×
867
                if err != nil {
×
868
                        return fmt.Errorf("unable to fetch source node: %w",
×
869
                                err)
×
870
                }
×
871

872
                return forEachNodeChannel(
×
873
                        ctx, db, s.cfg, nodeID,
×
874
                        func(info *models.ChannelEdgeInfo,
×
875
                                outPolicy *models.ChannelEdgePolicy,
×
876
                                _ *models.ChannelEdgePolicy) error {
×
877

×
878
                                // Fetch the other node.
×
879
                                var (
×
880
                                        otherNodePub [33]byte
×
881
                                        node1        = info.NodeKey1Bytes
×
882
                                        node2        = info.NodeKey2Bytes
×
883
                                )
×
884
                                switch {
×
885
                                case bytes.Equal(node1[:], nodePub[:]):
×
886
                                        otherNodePub = node2
×
887
                                case bytes.Equal(node2[:], nodePub[:]):
×
888
                                        otherNodePub = node1
×
889
                                default:
×
890
                                        return fmt.Errorf("node not " +
×
891
                                                "participating in this channel")
×
892
                                }
893

894
                                _, otherNode, err := getNodeByPubKey(
×
895
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
896
                                )
×
897
                                if err != nil {
×
898
                                        return fmt.Errorf("unable to fetch "+
×
899
                                                "other node(%x): %w",
×
900
                                                otherNodePub, err)
×
901
                                }
×
902

903
                                return cb(
×
904
                                        info.ChannelPoint, outPolicy != nil,
×
905
                                        otherNode,
×
906
                                )
×
907
                        },
908
                )
909
        }, reset)
910
}
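// Example (illustrative sketch, not part of the actual file): counting our
// own channels that are still missing a local policy. The reset closure must
// clear any state accumulated by the callback, since the surrounding
// transaction may be retried. countChannelsMissingPolicy is an invented name.
func countChannelsMissingPolicy(ctx context.Context, s *SQLStore) (int,
        error) {

        var missing int
        err := s.ForEachSourceNodeChannel(ctx,
                func(chanPoint wire.OutPoint, havePolicy bool,
                        _ *models.Node) error {

                        if !havePolicy {
                                missing++
                        }

                        return nil
                },
                func() {
                        missing = 0
                },
        )

        return missing, err
}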
911

912
// ForEachNode iterates through all the stored vertices/nodes in the graph,
913
// executing the passed callback with each node encountered. If the callback
914
// returns an error, then the transaction is aborted and the iteration stops
915
// early.
916
//
917
// NOTE: part of the V1Store interface.
918
func (s *SQLStore) ForEachNode(ctx context.Context,
919
        cb func(node *models.Node) error, reset func()) error {
×
920

×
921
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
922
                return forEachNodePaginated(
×
923
                        ctx, s.cfg.QueryCfg, db,
×
924
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
925
                                node *models.Node) error {
×
926

×
927
                                return cb(node)
×
928
                        },
×
929
                )
930
        }, reset)
931
}
932

933
// ForEachNodeDirectedChannel iterates through all channels of a given node,
934
// executing the passed callback on the directed edge representing the channel
935
// and its incoming policy. If the callback returns an error, then the iteration
936
// is halted with the error propagated back up to the caller.
937
//
938
// Unknown policies are passed into the callback as nil values.
939
//
940
// NOTE: this is part of the graphdb.NodeTraverser interface.
941
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
942
        cb func(channel *DirectedChannel) error, reset func()) error {
×
943

×
944
        var ctx = context.TODO()
×
945

×
946
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
947
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
948
        }, reset)
×
949
}
950

951
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
952
// graph, executing the passed callback with each node encountered. If the
953
// callback returns an error, then the transaction is aborted and the iteration
954
// stops early.
955
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
956
        cb func(route.Vertex, *lnwire.FeatureVector) error,
957
        reset func()) error {
×
958

×
959
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
960
                return forEachNodeCacheable(
×
961
                        ctx, s.cfg.QueryCfg, db,
×
962
                        func(_ int64, nodePub route.Vertex,
×
963
                                features *lnwire.FeatureVector) error {
×
964

×
965
                                return cb(nodePub, features)
×
966
                        },
×
967
                )
968
        }, reset)
969
        if err != nil {
×
970
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
971
        }
×
972

973
        return nil
×
974
}
975

976
// ForEachNodeChannel iterates through all channels of the given node,
977
// executing the passed callback with an edge info structure and the policies
978
// of each end of the channel. The first edge policy is the outgoing edge *to*
979
// the connecting node, while the second is the incoming edge *from* the
980
// connecting node. If the callback returns an error, then the iteration is
981
// halted with the error propagated back up to the caller.
982
//
983
// Unknown policies are passed into the callback as nil values.
984
//
985
// NOTE: part of the V1Store interface.
986
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
987
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
988
                *models.ChannelEdgePolicy) error, reset func()) error {
×
989

×
990
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
991
                dbNode, err := db.GetNodeByPubKey(
×
992
                        ctx, sqlc.GetNodeByPubKeyParams{
×
993
                                Version: int16(lnwire.GossipVersion1),
×
994
                                PubKey:  nodePub[:],
×
995
                        },
×
996
                )
×
997
                if errors.Is(err, sql.ErrNoRows) {
×
998
                        return nil
×
999
                } else if err != nil {
×
1000
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1001
                }
×
1002

1003
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
1004
        }, reset)
1005
}
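// Example (illustrative sketch, not part of the actual file): iterating a
// node's channels and distinguishing the two policies the callback receives:
// the first is the outgoing edge to the peer, the second the incoming edge
// from the peer, and either may be nil. channelsWithBothPolicies is an
// invented name.
func channelsWithBothPolicies(ctx context.Context, s *SQLStore,
        nodePub route.Vertex) (int, error) {

        var count int
        err := s.ForEachNodeChannel(ctx, nodePub,
                func(info *models.ChannelEdgeInfo,
                        outPolicy *models.ChannelEdgePolicy,
                        inPolicy *models.ChannelEdgePolicy) error {

                        if outPolicy != nil && inPolicy != nil {
                                count++
                        }

                        return nil
                },
                func() {
                        count = 0
                },
        )

        return count, err
}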
1006

1007
// extractMaxUpdateTime returns the maximum of the two policy update times.
1008
// This is used for pagination cursor tracking.
1009
func extractMaxUpdateTime(
1010
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1011

×
1012
        switch {
×
1013
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1014
                return max(row.Policy1LastUpdate.Int64,
×
1015
                        row.Policy2LastUpdate.Int64)
×
1016
        case row.Policy1LastUpdate.Valid:
×
1017
                return row.Policy1LastUpdate.Int64
×
1018
        case row.Policy2LastUpdate.Valid:
×
1019
                return row.Policy2LastUpdate.Int64
×
1020
        default:
×
1021
                return 0
×
1022
        }
1023
}
1024

1025
// buildChannelFromRow constructs a ChannelEdge from a database row.
1026
// This includes building the nodes, channel info, and policies.
1027
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1028
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1029

×
1030
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1031
        if err != nil {
×
1032
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1033
                        err)
×
1034
        }
×
1035

1036
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1037
        if err != nil {
×
1038
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1039
                        err)
×
1040
        }
×
1041

1042
        channel, err := getAndBuildEdgeInfo(
×
1043
                ctx, s.cfg, db,
×
1044
                row.GraphChannel, node1.PubKeyBytes,
×
1045
                node2.PubKeyBytes,
×
1046
        )
×
1047
        if err != nil {
×
1048
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1049
                        "channel info: %w", err)
×
1050
        }
×
1051

1052
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1053
        if err != nil {
×
1054
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1055
                        "channel policies: %w", err)
×
1056
        }
×
1057

1058
        p1, p2, err := getAndBuildChanPolicies(
×
1059
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1060
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1061
        )
×
1062
        if err != nil {
×
1063
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1064
                        "channel policies: %w", err)
×
1065
        }
×
1066

1067
        return ChannelEdge{
×
1068
                Info:    channel,
×
1069
                Policy1: p1,
×
1070
                Policy2: p2,
×
1071
                Node1:   node1,
×
1072
                Node2:   node2,
×
1073
        }, nil
×
1074
}
1075

1076
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1077
// This method acquires the cache lock only once for the entire batch.
1078
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1079
        if len(edgesToCache) == 0 {
×
1080
                return
×
1081
        }
×
1082

1083
        s.cacheMu.Lock()
×
1084
        defer s.cacheMu.Unlock()
×
1085

×
1086
        for chanID, edge := range edgesToCache {
×
1087
                s.chanCache.insert(chanID, edge)
×
1088
        }
×
1089
}
1090

1091
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1092
// one edge that has an update timestamp within the specified horizon.
1093
//
1094
// Iterator Lifecycle:
1095
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1096
// 2. Query batch of channels with policies in time range
1097
// 3. For each channel: check if seen, check cache, or build from DB
1098
// 4. Yield channels to caller
1099
// 5. Update cache after successful batch
1100
// 6. Repeat with updated pagination cursor until no more results
1101
//
1102
// NOTE: This is part of the V1Store interface.
1103
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1104
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1105

×
1106
        // Apply options.
×
1107
        cfg := defaultIteratorConfig()
×
1108
        for _, opt := range opts {
×
1109
                opt(cfg)
×
1110
        }
×
1111

1112
        return func(yield func(ChannelEdge, error) bool) {
×
1113
                var (
×
1114
                        ctx            = context.TODO()
×
1115
                        edgesSeen      = make(map[uint64]struct{})
×
1116
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1117
                        hits           int
×
1118
                        total          int
×
1119
                        lastUpdateTime sql.NullInt64
×
1120
                        lastID         sql.NullInt64
×
1121
                        hasMore        = true
×
1122
                )
×
1123

×
1124
                // Each iteration, we'll read a batch of channel updates
×
1125
                // (consulting the cache along the way), yield them, then loop
×
1126
                // back to decide if we have any more updates to read out.
×
1127
                for hasMore {
×
1128
                        var batch []ChannelEdge
×
1129

×
NEW
1130
                        // Acquire read lock before starting transaction to
×
NEW
1131
                        // ensure consistent lock ordering (cacheMu -> DB) and
×
NEW
1132
                        // prevent deadlock with write operations.
×
NEW
1133
                        s.cacheMu.RLock()
×
NEW
1134

×
1135
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1136
                                func(db SQLQueries) error {
×
1137
                                        //nolint:ll
×
1138
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1139
                                                Version: int16(lnwire.GossipVersion1),
×
1140
                                                StartTime: sqldb.SQLInt64(
×
1141
                                                        startTime.Unix(),
×
1142
                                                ),
×
1143
                                                EndTime: sqldb.SQLInt64(
×
1144
                                                        endTime.Unix(),
×
1145
                                                ),
×
1146
                                                LastUpdateTime: lastUpdateTime,
×
1147
                                                LastID:         lastID,
×
1148
                                                MaxResults: sql.NullInt32{
×
1149
                                                        Int32: int32(
×
1150
                                                                cfg.chanUpdateIterBatchSize,
×
1151
                                                        ),
×
1152
                                                        Valid: true,
×
1153
                                                },
×
1154
                                        }
×
1155
                                        //nolint:ll
×
1156
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1157
                                                ctx, params,
×
1158
                                        )
×
1159
                                        if err != nil {
×
1160
                                                return err
×
1161
                                        }
×
1162

1163
                                        //nolint:ll
1164
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1165

×
1166
                                        //nolint:ll
×
1167
                                        for _, row := range rows {
×
1168
                                                lastUpdateTime = sql.NullInt64{
×
1169
                                                        Int64: extractMaxUpdateTime(row),
×
1170
                                                        Valid: true,
×
1171
                                                }
×
1172
                                                lastID = sql.NullInt64{
×
1173
                                                        Int64: row.GraphChannel.ID,
×
1174
                                                        Valid: true,
×
1175
                                                }
×
1176

×
1177
                                                // Skip if we've already
×
1178
                                                // processed this channel.
×
1179
                                                chanIDInt := byteOrder.Uint64(
×
1180
                                                        row.GraphChannel.Scid,
×
1181
                                                )
×
1182
                                                _, ok := edgesSeen[chanIDInt]
×
1183
                                                if ok {
×
1184
                                                        continue
×
1185
                                                }
1186

1187
                                                // Check cache (we already hold
1188
                                                // shared read lock).
1189
                                                channel, ok := s.chanCache.get(
×
1190
                                                        chanIDInt,
×
1191
                                                )
×
1192
                                                if ok {
×
1193
                                                        hits++
×
1194
                                                        total++
×
1195
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1196
                                                        batch = append(batch, channel)
×
1197

×
1198
                                                        continue
×
1199
                                                }
1200

1201
                                                chanEdge, err := s.buildChannelFromRow(
×
1202
                                                        ctx, db, row,
×
1203
                                                )
×
1204
                                                if err != nil {
×
1205
                                                        return err
×
1206
                                                }
×
1207

1208
                                                edgesSeen[chanIDInt] = struct{}{}
×
1209
                                                edgesToCache[chanIDInt] = chanEdge
×
1210

×
1211
                                                batch = append(batch, chanEdge)
×
1212

×
1213
                                                total++
×
1214
                                        }
1215

1216
                                        return nil
×
1217
                                }, func() {
×
1218
                                        batch = nil
×
1219
                                        edgesSeen = make(map[uint64]struct{})
×
1220
                                        edgesToCache = make(
×
1221
                                                map[uint64]ChannelEdge,
×
1222
                                        )
×
1223
                                })
×
1224

1225
                        // Release read lock after transaction completes.
NEW
1226
                        s.cacheMu.RUnlock()
×
NEW
1227

×
1228
                        if err != nil {
×
1229
                                log.Errorf("ChanUpdatesInHorizon "+
×
1230
                                        "batch error: %v", err)
×
1231

×
1232
                                yield(ChannelEdge{}, err)
×
1233

×
1234
                                return
×
1235
                        }
×
1236

1237
                        for _, edge := range batch {
×
1238
                                if !yield(edge, nil) {
×
1239
                                        return
×
1240
                                }
×
1241
                        }
1242

1243
                        // Update cache after successful batch yield, setting
1244
                        // the cache lock only once for the entire batch.
1245
                        s.updateChanCacheBatch(edgesToCache)
×
1246
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1247

×
1248
                        // If the batch didn't yield anything, then we're done.
×
1249
                        if len(batch) == 0 {
×
1250
                                break
×
1251
                        }
1252
                }
1253

1254
                if total > 0 {
×
1255
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1256
                                "%.2f (%d/%d)",
×
1257
                                float64(hits)*100/float64(total), hits, total)
×
1258
                } else {
×
1259
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1260
                                "in horizon (%s, %s)", startTime, endTime)
×
1261
                }
×
1262
        }
1263
}
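
// Editorial sketch (not part of the original file): a minimal example of how
// a caller might consume the iterator returned by ChanUpdatesInHorizon using
// Go's range-over-func syntax. The *SQLStore and time bounds are assumed to
// be constructed elsewhere; iteration stops on the first error yielded.
func exampleChanUpdatesInHorizon(s *SQLStore, start, end time.Time) error {
        for edge, err := range s.ChanUpdatesInHorizon(start, end) {
                if err != nil {
                        return err
                }

                log.Debugf("channel %d was updated within the horizon",
                        edge.Info.ChannelID)
        }

        return nil
}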
1264

1265
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1266
// data to the call-back. If withAddrs is true, then the call-back will also be
1267
// provided with the addresses associated with the node. The address retrieval
1268
// results in an additional round-trip to the database, so it should only be used
1269
// if the addresses are actually needed.
1270
//
1271
// NOTE: part of the V1Store interface.
1272
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1273
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1274
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1275

×
1276
        type nodeCachedBatchData struct {
×
1277
                features      map[int64][]int
×
1278
                addrs         map[int64][]nodeAddress
×
1279
                chanBatchData *batchChannelData
×
1280
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1281
        }
×
1282

×
1283
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1284
                // pageQueryFunc is used to query the next page of nodes.
×
1285
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1286
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1287

×
1288
                        return db.ListNodeIDsAndPubKeys(
×
1289
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1290
                                        Version: int16(lnwire.GossipVersion1),
×
1291
                                        ID:      lastID,
×
1292
                                        Limit:   limit,
×
1293
                                },
×
1294
                        )
×
1295
                }
×
1296

1297
                // batchDataFunc is then used to batch load the data required
1298
                // for each page of nodes.
1299
                batchDataFunc := func(ctx context.Context,
×
1300
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1301

×
1302
                        // Batch load node features.
×
1303
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1304
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1305
                        )
×
1306
                        if err != nil {
×
1307
                                return nil, fmt.Errorf("unable to batch load "+
×
1308
                                        "node features: %w", err)
×
1309
                        }
×
1310

1311
                        // Maybe fetch the node's addresses if requested.
1312
                        var nodeAddrs map[int64][]nodeAddress
×
1313
                        if withAddrs {
×
1314
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1315
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1316
                                )
×
1317
                                if err != nil {
×
1318
                                        return nil, fmt.Errorf("unable to "+
×
1319
                                                "batch load node "+
×
1320
                                                "addresses: %w", err)
×
1321
                                }
×
1322
                        }
1323

1324
                        // Batch load ALL unique channels for ALL nodes in this
1325
                        // page.
1326
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1327
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1328
                                        Version:  int16(lnwire.GossipVersion1),
×
1329
                                        Node1Ids: nodeIDs,
×
1330
                                        Node2Ids: nodeIDs,
×
1331
                                },
×
1332
                        )
×
1333
                        if err != nil {
×
1334
                                return nil, fmt.Errorf("unable to batch "+
×
1335
                                        "fetch channels for nodes: %w", err)
×
1336
                        }
×
1337

1338
                        // Deduplicate channels and collect IDs.
1339
                        var (
×
1340
                                allChannelIDs []int64
×
1341
                                allPolicyIDs  []int64
×
1342
                        )
×
1343
                        uniqueChannels := make(
×
1344
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1345
                        )
×
1346

×
1347
                        for _, channel := range allChannels {
×
1348
                                channelID := channel.GraphChannel.ID
×
1349

×
1350
                                // Only process each unique channel once.
×
1351
                                _, exists := uniqueChannels[channelID]
×
1352
                                if exists {
×
1353
                                        continue
×
1354
                                }
1355

1356
                                uniqueChannels[channelID] = channel
×
1357
                                allChannelIDs = append(allChannelIDs, channelID)
×
1358

×
1359
                                if channel.Policy1ID.Valid {
×
1360
                                        allPolicyIDs = append(
×
1361
                                                allPolicyIDs,
×
1362
                                                channel.Policy1ID.Int64,
×
1363
                                        )
×
1364
                                }
×
1365
                                if channel.Policy2ID.Valid {
×
1366
                                        allPolicyIDs = append(
×
1367
                                                allPolicyIDs,
×
1368
                                                channel.Policy2ID.Int64,
×
1369
                                        )
×
1370
                                }
×
1371
                        }
1372

1373
                        // Batch load channel data for all unique channels.
1374
                        channelBatchData, err := batchLoadChannelData(
×
1375
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1376
                                allPolicyIDs,
×
1377
                        )
×
1378
                        if err != nil {
×
1379
                                return nil, fmt.Errorf("unable to batch "+
×
1380
                                        "load channel data: %w", err)
×
1381
                        }
×
1382

1383
                        // Create map of node ID to channels that involve this
1384
                        // node.
1385
                        nodeIDSet := make(map[int64]bool)
×
1386
                        for _, nodeID := range nodeIDs {
×
1387
                                nodeIDSet[nodeID] = true
×
1388
                        }
×
1389

1390
                        nodeChannelMap := make(
×
1391
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1392
                        )
×
1393
                        for _, channel := range uniqueChannels {
×
1394
                                // Add channel to both nodes if they're in our
×
1395
                                // current page.
×
1396
                                node1 := channel.GraphChannel.NodeID1
×
1397
                                if nodeIDSet[node1] {
×
1398
                                        nodeChannelMap[node1] = append(
×
1399
                                                nodeChannelMap[node1], channel,
×
1400
                                        )
×
1401
                                }
×
1402
                                node2 := channel.GraphChannel.NodeID2
×
1403
                                if nodeIDSet[node2] {
×
1404
                                        nodeChannelMap[node2] = append(
×
1405
                                                nodeChannelMap[node2], channel,
×
1406
                                        )
×
1407
                                }
×
1408
                        }
1409

1410
                        return &nodeCachedBatchData{
×
1411
                                features:      nodeFeatures,
×
1412
                                addrs:         nodeAddrs,
×
1413
                                chanBatchData: channelBatchData,
×
1414
                                chanMap:       nodeChannelMap,
×
1415
                        }, nil
×
1416
                }
1417

1418
                // processItem is used to process each node in the current page.
1419
                processItem := func(ctx context.Context,
×
1420
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1421
                        batchData *nodeCachedBatchData) error {
×
1422

×
1423
                        // Build feature vector for this node.
×
1424
                        fv := lnwire.EmptyFeatureVector()
×
1425
                        features, exists := batchData.features[nodeData.ID]
×
1426
                        if exists {
×
1427
                                for _, bit := range features {
×
1428
                                        fv.Set(lnwire.FeatureBit(bit))
×
1429
                                }
×
1430
                        }
1431

1432
                        var nodePub route.Vertex
×
1433
                        copy(nodePub[:], nodeData.PubKey)
×
1434

×
1435
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1436

×
1437
                        toNodeCallback := func() route.Vertex {
×
1438
                                return nodePub
×
1439
                        }
×
1440

1441
                        // Build cached channels map for this node.
1442
                        channels := make(map[uint64]*DirectedChannel)
×
1443
                        for _, channelRow := range nodeChannels {
×
1444
                                directedChan, err := buildDirectedChannel(
×
1445
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1446
                                        channelRow, batchData.chanBatchData, fv,
×
1447
                                        toNodeCallback,
×
1448
                                )
×
1449
                                if err != nil {
×
1450
                                        return err
×
1451
                                }
×
1452

1453
                                channels[directedChan.ChannelID] = directedChan
×
1454
                        }
1455

1456
                        addrs, err := buildNodeAddresses(
×
1457
                                batchData.addrs[nodeData.ID],
×
1458
                        )
×
1459
                        if err != nil {
×
1460
                                return fmt.Errorf("unable to build node "+
×
1461
                                        "addresses: %w", err)
×
1462
                        }
×
1463

1464
                        return cb(ctx, nodePub, addrs, channels)
×
1465
                }
1466

1467
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1468
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1469
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1470
                                return node.ID
×
1471
                        },
×
1472
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1473
                                error) {
×
1474

×
1475
                                return node.ID, nil
×
1476
                        },
×
1477
                        batchDataFunc, processItem,
1478
                )
1479
        }, reset)
1480
}
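
// Editorial sketch (not part of the original file): one possible call pattern
// for ForEachNodeCached that counts the channels of every node without
// fetching addresses. The context and *SQLStore are assumed to exist already;
// the reset closure discards any partial state if the read tx is retried.
func exampleForEachNodeCached(ctx context.Context, s *SQLStore) (int, error) {
        var numChans int
        err := s.ForEachNodeCached(ctx, false,
                func(_ context.Context, node route.Vertex, _ []net.Addr,
                        chans map[uint64]*DirectedChannel) error {

                        log.Debugf("node %x has %d channels", node[:],
                                len(chans))
                        numChans += len(chans)

                        return nil
                },
                func() {
                        numChans = 0
                },
        )

        return numChans, err
}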
1481

1482
// ForEachChannelCacheable iterates through all the channel edges stored
1483
// within the graph and invokes the passed callback for each edge. The
1484
// callback takes two edges since this is a directed graph: both the
1485
// in/out edges are visited. If the callback returns an error, then the
1486
// transaction is aborted and the iteration stops early.
1487
//
1488
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1489
// pointer for that particular channel edge routing policy will be
1490
// passed into the callback.
1491
//
1492
// NOTE: this method is like ForEachChannel but fetches only the data
1493
// required for the graph cache.
1494
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1495
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1496
        reset func()) error {
×
1497

×
1498
        ctx := context.TODO()
×
1499

×
1500
        handleChannel := func(_ context.Context,
×
1501
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1502

×
1503
                node1, node2, err := buildNodeVertices(
×
1504
                        row.Node1Pubkey, row.Node2Pubkey,
×
1505
                )
×
1506
                if err != nil {
×
1507
                        return err
×
1508
                }
×
1509

1510
                edge := buildCacheableChannelInfo(
×
1511
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1512
                )
×
1513

×
1514
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1515
                if err != nil {
×
1516
                        return err
×
1517
                }
×
1518

1519
                pol1, pol2, err := buildCachedChanPolicies(
×
1520
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1521
                )
×
1522
                if err != nil {
×
1523
                        return err
×
1524
                }
×
1525

1526
                return cb(edge, pol1, pol2)
×
1527
        }
1528

1529
        extractCursor := func(
×
1530
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1531

×
1532
                return row.ID
×
1533
        }
×
1534

1535
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1536
                //nolint:ll
×
1537
                queryFunc := func(ctx context.Context, lastID int64,
×
1538
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1539
                        error) {
×
1540

×
1541
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1542
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1543
                                        Version: int16(lnwire.GossipVersion1),
×
1544
                                        ID:      lastID,
×
1545
                                        Limit:   limit,
×
1546
                                },
×
1547
                        )
×
1548
                }
×
1549

1550
                return sqldb.ExecutePaginatedQuery(
×
1551
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1552
                        extractCursor, handleChannel,
×
1553
                )
×
1554
        }, reset)
1555
}
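
// Editorial sketch (not part of the original file): counting all cacheable
// channel edges. Only identifiers already present in this file are used; the
// reset closure clears the running count if the transaction is retried.
func exampleCountCacheableChannels(s *SQLStore) (int, error) {
        var count int
        err := s.ForEachChannelCacheable(
                func(_ *models.CachedEdgeInfo, _ *models.CachedEdgePolicy,
                        _ *models.CachedEdgePolicy) error {

                        count++

                        return nil
                },
                func() {
                        count = 0
                },
        )

        return count, err
}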
1556

1557
// ForEachChannel iterates through all the channel edges stored within the
1558
// graph and invokes the passed callback for each edge. The callback takes two
1559
// edges since this is a directed graph: both the in/out edges are visited.
1560
// If the callback returns an error, then the transaction is aborted and the
1561
// iteration stops early.
1562
//
1563
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1564
// for that particular channel edge routing policy will be passed into the
1565
// callback.
1566
//
1567
// NOTE: part of the V1Store interface.
1568
func (s *SQLStore) ForEachChannel(ctx context.Context,
1569
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1570
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1571

×
1572
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1573
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1574
        }, reset)
×
1575
}
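
// Editorial sketch (not part of the original file): ForEachChannel with a
// callback that illustrates the nil-policy contract described above, i.e. an
// edge may be visited before either directed policy has been received.
func exampleForEachChannel(ctx context.Context, s *SQLStore) error {
        return s.ForEachChannel(ctx,
                func(info *models.ChannelEdgeInfo,
                        p1, p2 *models.ChannelEdgePolicy) error {

                        if p1 == nil || p2 == nil {
                                log.Debugf("channel %d is missing at least "+
                                        "one policy", info.ChannelID)
                        }

                        return nil
                },
                func() {},
        )
}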
1576

1577
// FilterChannelRange returns the channel ID's of all known channels which were
1578
// mined in a block height within the passed range. The channel IDs are grouped
1579
// by their common block height. This method can be used to quickly share with a
1580
// peer the set of channels we know of within a particular range to catch them
1581
// up after a period of time offline. If withTimestamps is true then the
1582
// timestamp info of the latest received channel update messages of the channel
1583
// will be included in the response.
1584
//
1585
// NOTE: This is part of the V1Store interface.
1586
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1587
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1588

×
1589
        var (
×
1590
                ctx       = context.TODO()
×
1591
                startSCID = &lnwire.ShortChannelID{
×
1592
                        BlockHeight: startHeight,
×
1593
                }
×
1594
                endSCID = lnwire.ShortChannelID{
×
1595
                        BlockHeight: endHeight,
×
1596
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1597
                        TxPosition:  math.MaxUint16,
×
1598
                }
×
1599
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1600
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1601
        )
×
1602

×
1603
        // 1) get all channels where channelID is between start and end chan ID.
×
1604
        // 2) skip if not public (ie, no channel_proof)
×
1605
        // 3) collect that channel.
×
1606
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1607
        //    and add those timestamps to the collected channel.
×
1608
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1609
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1610
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1611
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1612
                                StartScid: chanIDStart,
×
1613
                                EndScid:   chanIDEnd,
×
1614
                        },
×
1615
                )
×
1616
                if err != nil {
×
1617
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1618
                                err)
×
1619
                }
×
1620

1621
                for _, dbChan := range dbChans {
×
1622
                        cid := lnwire.NewShortChanIDFromInt(
×
1623
                                byteOrder.Uint64(dbChan.Scid),
×
1624
                        )
×
1625
                        chanInfo := NewChannelUpdateInfo(
×
1626
                                cid, time.Time{}, time.Time{},
×
1627
                        )
×
1628

×
1629
                        if !withTimestamps {
×
1630
                                channelsPerBlock[cid.BlockHeight] = append(
×
1631
                                        channelsPerBlock[cid.BlockHeight],
×
1632
                                        chanInfo,
×
1633
                                )
×
1634

×
1635
                                continue
×
1636
                        }
1637

1638
                        //nolint:ll
1639
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1640
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1641
                                        Version:   int16(lnwire.GossipVersion1),
×
1642
                                        ChannelID: dbChan.ID,
×
1643
                                        NodeID:    dbChan.NodeID1,
×
1644
                                },
×
1645
                        )
×
1646
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1647
                                return fmt.Errorf("unable to fetch node1 "+
×
1648
                                        "policy: %w", err)
×
1649
                        } else if err == nil {
×
1650
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1651
                                        node1Policy.LastUpdate.Int64, 0,
×
1652
                                )
×
1653
                        }
×
1654

1655
                        //nolint:ll
1656
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1657
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1658
                                        Version:   int16(lnwire.GossipVersion1),
×
1659
                                        ChannelID: dbChan.ID,
×
1660
                                        NodeID:    dbChan.NodeID2,
×
1661
                                },
×
1662
                        )
×
1663
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1664
                                return fmt.Errorf("unable to fetch node2 "+
×
1665
                                        "policy: %w", err)
×
1666
                        } else if err == nil {
×
1667
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1668
                                        node2Policy.LastUpdate.Int64, 0,
×
1669
                                )
×
1670
                        }
×
1671

1672
                        channelsPerBlock[cid.BlockHeight] = append(
×
1673
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1674
                        )
×
1675
                }
1676

1677
                return nil
×
1678
        }, func() {
×
1679
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1680
        })
×
1681
        if err != nil {
×
1682
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1683
        }
×
1684

1685
        if len(channelsPerBlock) == 0 {
×
1686
                return nil, nil
×
1687
        }
×
1688

1689
        // Return the channel ranges in ascending block height order.
1690
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1691
        slices.Sort(blocks)
×
1692

×
1693
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1694
                return BlockChannelRange{
×
1695
                        Height:   block,
×
1696
                        Channels: channelsPerBlock[block],
×
1697
                }
×
1698
        }), nil
×
1699
}
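
// Editorial sketch (not part of the original file): consuming the grouped
// result of FilterChannelRange, e.g. when replying to a peer's range query.
// The block heights are assumed to come from the peer's request.
func exampleFilterChannelRange(s *SQLStore, start, end uint32) error {
        ranges, err := s.FilterChannelRange(start, end, false)
        if err != nil {
                return err
        }

        for _, blockRange := range ranges {
                log.Debugf("height %d has %d known channels",
                        blockRange.Height, len(blockRange.Channels))
        }

        return nil
}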
1700

1701
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1702
// zombie. This method is used on an ad-hoc basis, when channels need to be
1703
// marked as zombies outside the normal pruning cycle.
1704
//
1705
// NOTE: part of the V1Store interface.
1706
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1707
        pubKey1, pubKey2 [33]byte) error {
×
1708

×
1709
        ctx := context.TODO()
×
1710

×
1711
        s.cacheMu.Lock()
×
1712
        defer s.cacheMu.Unlock()
×
1713

×
1714
        chanIDB := channelIDToBytes(chanID)
×
1715

×
1716
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1717
                return db.UpsertZombieChannel(
×
1718
                        ctx, sqlc.UpsertZombieChannelParams{
×
1719
                                Version:  int16(lnwire.GossipVersion1),
×
1720
                                Scid:     chanIDB,
×
1721
                                NodeKey1: pubKey1[:],
×
1722
                                NodeKey2: pubKey2[:],
×
1723
                        },
×
1724
                )
×
1725
        }, sqldb.NoOpReset)
×
1726
        if err != nil {
×
1727
                return fmt.Errorf("unable to upsert zombie channel "+
×
1728
                        "(channel_id=%d): %w", chanID, err)
×
1729
        }
×
1730

1731
        s.rejectCache.remove(chanID)
×
1732
        s.chanCache.remove(chanID)
×
1733

×
1734
        return nil
×
1735
}
1736

1737
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1738
//
1739
// NOTE: part of the V1Store interface.
1740
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1741
        s.cacheMu.Lock()
×
1742
        defer s.cacheMu.Unlock()
×
1743

×
1744
        var (
×
1745
                ctx     = context.TODO()
×
1746
                chanIDB = channelIDToBytes(chanID)
×
1747
        )
×
1748

×
1749
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1750
                res, err := db.DeleteZombieChannel(
×
1751
                        ctx, sqlc.DeleteZombieChannelParams{
×
1752
                                Scid:    chanIDB,
×
1753
                                Version: int16(lnwire.GossipVersion1),
×
1754
                        },
×
1755
                )
×
1756
                if err != nil {
×
1757
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1758
                                err)
×
1759
                }
×
1760

1761
                rows, err := res.RowsAffected()
×
1762
                if err != nil {
×
1763
                        return err
×
1764
                }
×
1765

1766
                if rows == 0 {
×
1767
                        return ErrZombieEdgeNotFound
×
1768
                } else if rows > 1 {
×
1769
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1770
                                "expected 1", rows)
×
1771
                }
×
1772

1773
                return nil
×
1774
        }, sqldb.NoOpReset)
1775
        if err != nil {
×
1776
                return fmt.Errorf("unable to mark edge live "+
×
1777
                        "(channel_id=%d): %w", chanID, err)
×
1778
        }
×
1779

1780
        s.rejectCache.remove(chanID)
×
1781
        s.chanCache.remove(chanID)
×
1782

×
1783
        return err
×
1784
}
1785

1786
// IsZombieEdge returns whether the edge is considered a zombie. If it is a
1787
// zombie, then the two node public keys corresponding to this edge are also
1788
// returned.
1789
//
1790
// NOTE: part of the V1Store interface.
1791
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1792
        error) {
×
1793

×
1794
        var (
×
1795
                ctx              = context.TODO()
×
1796
                isZombie         bool
×
1797
                pubKey1, pubKey2 route.Vertex
×
1798
                chanIDB          = channelIDToBytes(chanID)
×
1799
        )
×
1800

×
1801
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1802
                zombie, err := db.GetZombieChannel(
×
1803
                        ctx, sqlc.GetZombieChannelParams{
×
1804
                                Scid:    chanIDB,
×
1805
                                Version: int16(lnwire.GossipVersion1),
×
1806
                        },
×
1807
                )
×
1808
                if errors.Is(err, sql.ErrNoRows) {
×
1809
                        return nil
×
1810
                }
×
1811
                if err != nil {
×
1812
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1813
                                err)
×
1814
                }
×
1815

1816
                copy(pubKey1[:], zombie.NodeKey1)
×
1817
                copy(pubKey2[:], zombie.NodeKey2)
×
1818
                isZombie = true
×
1819

×
1820
                return nil
×
1821
        }, sqldb.NoOpReset)
1822
        if err != nil {
×
1823
                return false, route.Vertex{}, route.Vertex{},
×
1824
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1825
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1826
        }
×
1827

1828
        return isZombie, pubKey1, pubKey2, nil
×
1829
}
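
// Editorial sketch (not part of the original file): the zombie-index methods
// above used together. A channel is marked as a zombie, the index entry is
// confirmed via IsZombieEdge, and the edge is then resurrected again. The
// channel ID and node keys are assumed to come from the caller.
func exampleZombieLifecycle(s *SQLStore, chanID uint64,
        pubKey1, pubKey2 [33]byte) error {

        if err := s.MarkEdgeZombie(chanID, pubKey1, pubKey2); err != nil {
                return err
        }

        isZombie, _, _, err := s.IsZombieEdge(chanID)
        if err != nil {
                return err
        }
        if !isZombie {
                return fmt.Errorf("channel %d expected in zombie index",
                        chanID)
        }

        // Clearing the zombie entry deems the edge live again.
        return s.MarkEdgeLive(chanID)
}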
1830

1831
// NumZombies returns the current number of zombie channels in the graph.
1832
//
1833
// NOTE: part of the V1Store interface.
1834
func (s *SQLStore) NumZombies() (uint64, error) {
×
1835
        var (
×
1836
                ctx        = context.TODO()
×
1837
                numZombies uint64
×
1838
        )
×
1839
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1840
                count, err := db.CountZombieChannels(
×
1841
                        ctx, int16(lnwire.GossipVersion1),
×
1842
                )
×
1843
                if err != nil {
×
1844
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1845
                                err)
×
1846
                }
×
1847

1848
                numZombies = uint64(count)
×
1849

×
1850
                return nil
×
1851
        }, sqldb.NoOpReset)
1852
        if err != nil {
×
1853
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1854
        }
×
1855

1856
        return numZombies, nil
×
1857
}
1858

1859
// DeleteChannelEdges removes edges with the given channel IDs from the
1860
// database and marks them as zombies. This ensures that we're unable to re-add
1861
// it to our database once again. If an edge does not exist within the
1862
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1863
// true, then when we mark these edges as zombies, we'll set up the keys such
1864
// that we require the node that failed to send the fresh update to be the one
1865
// that resurrects the channel from its zombie state. The markZombie bool
1866
// denotes whether to mark the channel as a zombie.
1867
//
1868
// NOTE: part of the V1Store interface.
1869
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1870
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1871

×
1872
        s.cacheMu.Lock()
×
1873
        defer s.cacheMu.Unlock()
×
1874

×
1875
        // Keep track of which channels we end up finding so that we can
×
1876
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1877
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1878
        for _, chanID := range chanIDs {
×
1879
                chanLookup[chanID] = struct{}{}
×
1880
        }
×
1881

1882
        var (
×
1883
                ctx   = context.TODO()
×
1884
                edges []*models.ChannelEdgeInfo
×
1885
        )
×
1886
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1887
                // First, collect all channel rows.
×
1888
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1889
                chanCallBack := func(ctx context.Context,
×
1890
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1891

×
1892
                        // Deleting the entry from the map indicates that we
×
1893
                        // have found the channel.
×
1894
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1895
                        delete(chanLookup, scid)
×
1896

×
1897
                        channelRows = append(channelRows, row)
×
1898

×
1899
                        return nil
×
1900
                }
×
1901

1902
                err := s.forEachChanWithPoliciesInSCIDList(
×
1903
                        ctx, db, chanCallBack, chanIDs,
×
1904
                )
×
1905
                if err != nil {
×
1906
                        return err
×
1907
                }
×
1908

1909
                if len(chanLookup) > 0 {
×
1910
                        return ErrEdgeNotFound
×
1911
                }
×
1912

1913
                if len(channelRows) == 0 {
×
1914
                        return nil
×
1915
                }
×
1916

1917
                // Batch build all channel edges.
1918
                var chanIDsToDelete []int64
×
1919
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1920
                        ctx, s.cfg, db, channelRows,
×
1921
                )
×
1922
                if err != nil {
×
1923
                        return err
×
1924
                }
×
1925

1926
                if markZombie {
×
1927
                        for i, row := range channelRows {
×
1928
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1929

×
1930
                                err := handleZombieMarking(
×
1931
                                        ctx, db, row, edges[i],
×
1932
                                        strictZombiePruning, scid,
×
1933
                                )
×
1934
                                if err != nil {
×
1935
                                        return fmt.Errorf("unable to mark "+
×
1936
                                                "channel as zombie: %w", err)
×
1937
                                }
×
1938
                        }
1939
                }
1940

1941
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1942
        }, func() {
×
1943
                edges = nil
×
1944

×
1945
                // Re-fill the lookup map.
×
1946
                for _, chanID := range chanIDs {
×
1947
                        chanLookup[chanID] = struct{}{}
×
1948
                }
×
1949
        })
1950
        if err != nil {
×
1951
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1952
                        err)
×
1953
        }
×
1954

1955
        for _, chanID := range chanIDs {
×
1956
                s.rejectCache.remove(chanID)
×
1957
                s.chanCache.remove(chanID)
×
1958
        }
×
1959

1960
        return edges, nil
×
1961
}
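
// Editorial sketch (not part of the original file): deleting a set of channel
// edges while marking them as zombies with strict pruning enabled, returning
// the edge infos that were removed.
func exampleDeleteChannelEdges(s *SQLStore,
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {

        const (
                strictPruning = true
                markZombie    = true
        )

        return s.DeleteChannelEdges(strictPruning, markZombie, chanIDs...)
}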
1962

1963
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1964
// channel identified by the channel ID. If the channel can't be found, then
1965
// ErrEdgeNotFound is returned. A struct which houses the general information
1966
// for the channel itself is returned as well as two structs that contain the
1967
// routing policies for the channel in either direction.
1968
//
1969
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
1970
// within the database. In this case, the ChannelEdgePolicies will be nil, and
1971
// the ChannelEdgeInfo will only include the public keys of each node.
1972
//
1973
// NOTE: part of the V1Store interface.
1974
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1975
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1976
        *models.ChannelEdgePolicy, error) {
×
1977

×
1978
        var (
×
1979
                ctx              = context.TODO()
×
1980
                edge             *models.ChannelEdgeInfo
×
1981
                policy1, policy2 *models.ChannelEdgePolicy
×
1982
                chanIDB          = channelIDToBytes(chanID)
×
1983
        )
×
1984
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1985
                row, err := db.GetChannelBySCIDWithPolicies(
×
1986
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1987
                                Scid:    chanIDB,
×
1988
                                Version: int16(lnwire.GossipVersion1),
×
1989
                        },
×
1990
                )
×
1991
                if errors.Is(err, sql.ErrNoRows) {
×
1992
                        // First check if this edge is perhaps in the zombie
×
1993
                        // index.
×
1994
                        zombie, err := db.GetZombieChannel(
×
1995
                                ctx, sqlc.GetZombieChannelParams{
×
1996
                                        Scid:    chanIDB,
×
1997
                                        Version: int16(lnwire.GossipVersion1),
×
1998
                                },
×
1999
                        )
×
2000
                        if errors.Is(err, sql.ErrNoRows) {
×
2001
                                return ErrEdgeNotFound
×
2002
                        } else if err != nil {
×
2003
                                return fmt.Errorf("unable to check if "+
×
2004
                                        "channel is zombie: %w", err)
×
2005
                        }
×
2006

2007
                        // At this point, we know the channel is a zombie, so
2008
                        // we'll return an error indicating this, and we will
2009
                        // populate the edge info with the public keys of each
2010
                        // party as this is the only information we have about
2011
                        // it.
2012
                        edge = &models.ChannelEdgeInfo{}
×
2013
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
2014
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
2015

×
2016
                        return ErrZombieEdge
×
2017
                } else if err != nil {
×
2018
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2019
                }
×
2020

2021
                node1, node2, err := buildNodeVertices(
×
2022
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2023
                )
×
2024
                if err != nil {
×
2025
                        return err
×
2026
                }
×
2027

2028
                edge, err = getAndBuildEdgeInfo(
×
2029
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2030
                )
×
2031
                if err != nil {
×
2032
                        return fmt.Errorf("unable to build channel info: %w",
×
2033
                                err)
×
2034
                }
×
2035

2036
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2037
                if err != nil {
×
2038
                        return fmt.Errorf("unable to extract channel "+
×
2039
                                "policies: %w", err)
×
2040
                }
×
2041

2042
                policy1, policy2, err = getAndBuildChanPolicies(
×
2043
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2044
                        node1, node2,
×
2045
                )
×
2046
                if err != nil {
×
2047
                        return fmt.Errorf("unable to build channel "+
×
2048
                                "policies: %w", err)
×
2049
                }
×
2050

2051
                return nil
×
2052
        }, sqldb.NoOpReset)
2053
        if err != nil {
×
2054
                // If we are returning the ErrZombieEdge, then we also need to
×
2055
                // return the edge info as the method comment indicates that
×
2056
                // this will be populated when the edge is a zombie.
×
2057
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2058
                        err)
×
2059
        }
×
2060

2061
        return edge, policy1, policy2, nil
×
2062
}
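
// Editorial sketch (not part of the original file): distinguishing a zombie
// edge from a missing one when calling FetchChannelEdgesByID, per the
// contract described in the method comment above.
func exampleFetchChannelEdgesByID(s *SQLStore, chanID uint64) error {
        info, _, _, err := s.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                // Only the node public keys are populated in this case.
                log.Debugf("channel %d is a zombie between %x and %x",
                        chanID, info.NodeKey1Bytes[:], info.NodeKey2Bytes[:])

                return nil

        case errors.Is(err, ErrEdgeNotFound):
                return err

        case err != nil:
                return err
        }

        return nil
}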
2063

2064
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2065
// the channel identified by the funding outpoint. If the channel can't be
2066
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2067
// information for the channel itself is returned as well as two structs that
2068
// contain the routing policies for the channel in either direction.
2069
//
2070
// NOTE: part of the V1Store interface.
2071
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2072
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2073
        *models.ChannelEdgePolicy, error) {
×
2074

×
2075
        var (
×
2076
                ctx              = context.TODO()
×
2077
                edge             *models.ChannelEdgeInfo
×
2078
                policy1, policy2 *models.ChannelEdgePolicy
×
2079
        )
×
2080
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2081
                row, err := db.GetChannelByOutpointWithPolicies(
×
2082
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2083
                                Outpoint: op.String(),
×
2084
                                Version:  int16(lnwire.GossipVersion1),
×
2085
                        },
×
2086
                )
×
2087
                if errors.Is(err, sql.ErrNoRows) {
×
2088
                        return ErrEdgeNotFound
×
2089
                } else if err != nil {
×
2090
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2091
                }
×
2092

2093
                node1, node2, err := buildNodeVertices(
×
2094
                        row.Node1Pubkey, row.Node2Pubkey,
×
2095
                )
×
2096
                if err != nil {
×
2097
                        return err
×
2098
                }
×
2099

2100
                edge, err = getAndBuildEdgeInfo(
×
2101
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2102
                )
×
2103
                if err != nil {
×
2104
                        return fmt.Errorf("unable to build channel info: %w",
×
2105
                                err)
×
2106
                }
×
2107

2108
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2109
                if err != nil {
×
2110
                        return fmt.Errorf("unable to extract channel "+
×
2111
                                "policies: %w", err)
×
2112
                }
×
2113

2114
                policy1, policy2, err = getAndBuildChanPolicies(
×
2115
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2116
                        node1, node2,
×
2117
                )
×
2118
                if err != nil {
×
2119
                        return fmt.Errorf("unable to build channel "+
×
2120
                                "policies: %w", err)
×
2121
                }
×
2122

2123
                return nil
×
2124
        }, sqldb.NoOpReset)
2125
        if err != nil {
×
2126
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2127
                        err)
×
2128
        }
×
2129

2130
        return edge, policy1, policy2, nil
×
2131
}
2132

2133
// HasChannelEdge returns true if the database knows of a channel edge with the
2134
// passed channel ID, and false otherwise. If an edge with that ID is found
2135
// within the graph, then two time stamps representing the last time the edge
2136
// was updated for both directed edges are returned along with the boolean. If
2137
// it is not found, then the zombie index is checked and its result is returned
2138
// as the second boolean.
2139
//
2140
// NOTE: part of the V1Store interface.
2141
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2142
        bool, error) {
×
2143

×
2144
        ctx := context.TODO()
×
2145

×
2146
        var (
×
2147
                exists          bool
×
2148
                isZombie        bool
×
2149
                node1LastUpdate time.Time
×
2150
                node2LastUpdate time.Time
×
2151
        )
×
2152

×
2153
        // We'll query the cache with the shared lock held to allow multiple
×
2154
        // readers to access values in the cache concurrently if they exist.
×
2155
        s.cacheMu.RLock()
×
2156
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2157
                s.cacheMu.RUnlock()
×
2158
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2159
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2160
                exists, isZombie = entry.flags.unpack()
×
2161

×
2162
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2163
        }
×
2164
        s.cacheMu.RUnlock()
×
2165

×
2166
        s.cacheMu.Lock()
×
2167
        defer s.cacheMu.Unlock()
×
2168

×
2169
        // The item was not found with the shared lock, so we'll acquire the
×
2170
        // exclusive lock and check the cache again in case another method added
×
2171
        // the entry to the cache while no lock was held.
×
2172
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2173
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2174
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2175
                exists, isZombie = entry.flags.unpack()
×
2176

×
2177
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2178
        }
×
2179

2180
        chanIDB := channelIDToBytes(chanID)
×
2181
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2182
                channel, err := db.GetChannelBySCID(
×
2183
                        ctx, sqlc.GetChannelBySCIDParams{
×
2184
                                Scid:    chanIDB,
×
2185
                                Version: int16(lnwire.GossipVersion1),
×
2186
                        },
×
2187
                )
×
2188
                if errors.Is(err, sql.ErrNoRows) {
×
2189
                        // Check if it is a zombie channel.
×
2190
                        isZombie, err = db.IsZombieChannel(
×
2191
                                ctx, sqlc.IsZombieChannelParams{
×
2192
                                        Scid:    chanIDB,
×
2193
                                        Version: int16(lnwire.GossipVersion1),
×
2194
                                },
×
2195
                        )
×
2196
                        if err != nil {
×
2197
                                return fmt.Errorf("could not check if channel "+
×
2198
                                        "is zombie: %w", err)
×
2199
                        }
×
2200

2201
                        return nil
×
2202
                } else if err != nil {
×
2203
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2204
                }
×
2205

2206
                exists = true
×
2207

×
2208
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2209
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2210
                                Version:   int16(lnwire.GossipVersion1),
×
2211
                                ChannelID: channel.ID,
×
2212
                                NodeID:    channel.NodeID1,
×
2213
                        },
×
2214
                )
×
2215
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2216
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2217
                                err)
×
2218
                } else if err == nil {
×
2219
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2220
                }
×
2221

2222
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2223
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2224
                                Version:   int16(lnwire.GossipVersion1),
×
2225
                                ChannelID: channel.ID,
×
2226
                                NodeID:    channel.NodeID2,
×
2227
                        },
×
2228
                )
×
2229
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2230
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2231
                                err)
×
2232
                } else if err == nil {
×
2233
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2234
                }
×
2235

2236
                return nil
×
2237
        }, sqldb.NoOpReset)
2238
        if err != nil {
×
2239
                return time.Time{}, time.Time{}, false, false,
×
2240
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2241
        }
×
2242

2243
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2244
                upd1Time: node1LastUpdate.Unix(),
×
2245
                upd2Time: node2LastUpdate.Unix(),
×
2246
                flags:    packRejectFlags(exists, isZombie),
×
2247
        })
×
2248

×
2249
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2250
}
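
// Editorial sketch (not part of the original file): reading the two booleans
// returned by HasChannelEdge as described in the method comment above.
func exampleHasChannelEdge(s *SQLStore, chanID uint64) error {
        upd1, upd2, exists, isZombie, err := s.HasChannelEdge(chanID)
        if err != nil {
                return err
        }

        switch {
        case isZombie:
                log.Debugf("channel %d is a known zombie", chanID)

        case exists:
                log.Debugf("channel %d last updated at %v and %v", chanID,
                        upd1, upd2)

        default:
                log.Debugf("channel %d is unknown", chanID)
        }

        return nil
}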
2251

2252
// ChannelID attempts to look up the 8-byte compact channel ID which maps to the
2253
// passed channel point (outpoint). If the passed channel doesn't exist within
2254
// the database, then ErrEdgeNotFound is returned.
2255
//
2256
// NOTE: part of the V1Store interface.
2257
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2258
        var (
×
2259
                ctx       = context.TODO()
×
2260
                channelID uint64
×
2261
        )
×
2262
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2263
                chanID, err := db.GetSCIDByOutpoint(
×
2264
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2265
                                Outpoint: chanPoint.String(),
×
2266
                                Version:  int16(lnwire.GossipVersion1),
×
2267
                        },
×
2268
                )
×
2269
                if errors.Is(err, sql.ErrNoRows) {
×
2270
                        return ErrEdgeNotFound
×
2271
                } else if err != nil {
×
2272
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2273
                                err)
×
2274
                }
×
2275

2276
                channelID = byteOrder.Uint64(chanID)
×
2277

×
2278
                return nil
×
2279
        }, sqldb.NoOpReset)
2280
        if err != nil {
×
2281
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2282
        }
×
2283

2284
        return channelID, nil
×
2285
}
2286

2287
// IsPublicNode is a helper method that determines whether the node with the
2288
// given public key is seen as a public node in the graph from the graph's
2289
// source node's point of view.
2290
//
2291
// NOTE: part of the V1Store interface.
2292
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2293
        ctx := context.TODO()
×
2294

×
2295
        var isPublic bool
×
2296
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2297
                var err error
×
2298
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2299

×
2300
                return err
×
2301
        }, sqldb.NoOpReset)
×
2302
        if err != nil {
×
2303
                return false, fmt.Errorf("unable to check if node is "+
×
2304
                        "public: %w", err)
×
2305
        }
×
2306

2307
        return isPublic, nil
×
2308
}
2309

2310
// FetchChanInfos returns the set of channel edges that correspond to the passed
2311
// channel ID's. If an edge in the query is unknown to the database, it will be
2312
// skipped and the result will contain only those edges that exist at the time
2313
// of the query. This can be used to respond to peer queries that are seeking to
2314
// fill in gaps in their view of the channel graph.
2315
//
2316
// NOTE: part of the V1Store interface.
2317
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2318
        var (
×
2319
                ctx   = context.TODO()
×
2320
                edges = make(map[uint64]ChannelEdge)
×
2321
        )
×
2322
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2323
                // First, collect all channel rows.
×
2324
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2325
                chanCallBack := func(ctx context.Context,
×
2326
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2327

×
2328
                        channelRows = append(channelRows, row)
×
2329
                        return nil
×
2330
                }
×
2331

2332
                err := s.forEachChanWithPoliciesInSCIDList(
×
2333
                        ctx, db, chanCallBack, chanIDs,
×
2334
                )
×
2335
                if err != nil {
×
2336
                        return err
×
2337
                }
×
2338

2339
                if len(channelRows) == 0 {
×
2340
                        return nil
×
2341
                }
×
2342

2343
                // Batch build all channel edges.
2344
                chans, err := batchBuildChannelEdges(
×
2345
                        ctx, s.cfg, db, channelRows,
×
2346
                )
×
2347
                if err != nil {
×
2348
                        return fmt.Errorf("unable to build channel edges: %w",
×
2349
                                err)
×
2350
                }
×
2351

2352
                for _, c := range chans {
×
2353
                        edges[c.Info.ChannelID] = c
×
2354
                }
×
2355

2356
                return err
×
2357
        }, func() {
×
2358
                clear(edges)
×
2359
        })
×
2360
        if err != nil {
×
2361
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2362
        }
×
2363

2364
        res := make([]ChannelEdge, 0, len(edges))
×
2365
        for _, chanID := range chanIDs {
×
2366
                edge, ok := edges[chanID]
×
2367
                if !ok {
×
2368
                        continue
×
2369
                }
2370

2371
                res = append(res, edge)
×
2372
        }
2373

2374
        return res, nil
×
2375
}
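
// Illustrative usage (a minimal sketch, not part of the original file): when
// answering a peer's gossip query, the edges for the requested SCIDs can be
// fetched in one call; SCIDs unknown to us are simply absent from the result
// rather than producing an error. The helper name is an assumption made for
// this example only.
func exampleCollectKnownEdges(store *SQLStore,
        scids []lnwire.ShortChannelID) ([]ChannelEdge, error) {

        ids := make([]uint64, 0, len(scids))
        for _, scid := range scids {
                ids = append(ids, scid.ToUint64())
        }

        return store.FetchChanInfos(ids)
}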

// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
        db SQLQueries, cb func(ctx context.Context,
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
        chanIDs []uint64) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
                error) {

                return db.GetChannelsBySCIDWithPolicies(
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
                cb,
        )
}

// FilterKnownChanIDs takes a set of channel IDs and returns the subset of
// chan IDs that we don't know of and that are not known zombies of the passed
// set. In other words, we perform a set difference of our set of chan IDs and
// the ones passed in. This method can be used by callers to determine the set
// of channels another peer knows of that we don't. The ChannelUpdateInfos for
// the known zombies are also returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
        []ChannelUpdateInfo, error) {

        var (
                ctx          = context.TODO()
                newChanIDs   []uint64
                knownZombies []ChannelUpdateInfo
                infoLookup   = make(
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
                )
        )

        // We first build a lookup map of the channel ID's to the
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
        // already know about.
        for _, chanInfo := range chansInfo {
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
        }

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // The call-back function deletes known channels from
                // infoLookup, so that we can later check which channels are
                // zombies by only looking at the remaining channels in the set.
                cb := func(ctx context.Context,
                        channel sqlc.GraphChannel) error {

                        delete(infoLookup, byteOrder.Uint64(channel.Scid))

                        return nil
                }

                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
                if err != nil {
                        return fmt.Errorf("unable to iterate through "+
                                "channels: %w", err)
                }

                // We want to ensure that we deal with the channels in the
                // same order that they were passed in, so we iterate over the
                // original chansInfo slice and then check if that channel is
                // still in the infoLookup map.
                for _, chanInfo := range chansInfo {
                        channelID := chanInfo.ShortChannelID.ToUint64()
                        if _, ok := infoLookup[channelID]; !ok {
                                continue
                        }

                        isZombie, err := db.IsZombieChannel(
                                ctx, sqlc.IsZombieChannelParams{
                                        Scid:    channelIDToBytes(channelID),
                                        Version: int16(lnwire.GossipVersion1),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to fetch zombie "+
                                        "channel: %w", err)
                        }

                        if isZombie {
                                knownZombies = append(knownZombies, chanInfo)

                                continue
                        }

                        newChanIDs = append(newChanIDs, channelID)
                }

                return nil
        }, func() {
                newChanIDs = nil
                knownZombies = nil
                // Rebuild the infoLookup map in case of a rollback.
                for _, chanInfo := range chansInfo {
                        scid := chanInfo.ShortChannelID.ToUint64()
                        infoLookup[scid] = chanInfo
                }
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        return newChanIDs, knownZombies, nil
}
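
// Illustrative usage (a minimal sketch, not part of the original file):
// during gossip sync, the SCIDs advertised by a peer can be reduced to the
// subset that is actually worth querying, with known zombies dropped. The
// helper name is an assumption made for this example only.
func exampleNewSCIDsToQuery(store *SQLStore,
        peerChans []ChannelUpdateInfo) ([]uint64, error) {

        newSCIDs, zombies, err := store.FilterKnownChanIDs(peerChans)
        if err != nil {
                return nil, err
        }

        // Known zombies are already tracked locally, so they are intentionally
        // not re-queried here.
        _ = zombies

        return newSCIDs, nil
}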

// forEachChanInSCIDList is a helper method that executes a paged query
// against the database to fetch all channels that match the passed
// ChannelUpdateInfo slice. The callback function is called for each channel
// that is found.
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
        chansInfo []ChannelUpdateInfo) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GraphChannel, error) {

                return db.GetChannelsBySCIDs(
                        ctx, sqlc.GetChannelsBySCIDsParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
                channelID := chanInfo.ShortChannelID.ToUint64()

                return channelIDToBytes(channelID)
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
                cb,
        )
}

// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This
// ensures that we only maintain a graph of reachable nodes. In the event that
// a pruned node gains more channels, it will be re-added back to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
        var ctx = context.TODO()

        var prunedNodes []route.Vertex
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                var err error
                prunedNodes, err = s.pruneGraphNodes(ctx, db)

                return err
        }, func() {
                prunedNodes = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
        }

        return prunedNodes, nil
}

// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within the graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph, is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block along with any pruned nodes are returned if the function
// succeeds without error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) (
        []*models.ChannelEdgeInfo, []route.Vertex, error) {

        ctx := context.TODO()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                closedChans []*models.ChannelEdgeInfo
                prunedNodes []route.Vertex
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows that need to be pruned.
                var channelRows []sqlc.GetChannelsByOutpointsRow
                channelCallback := func(ctx context.Context,
                        row sqlc.GetChannelsByOutpointsRow) error {

                        channelRows = append(channelRows, row)

                        return nil
                }

                err := s.forEachChanInOutpoints(
                        ctx, db, spentOutputs, channelCallback,
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels by "+
                                "outpoints: %w", err)
                }

                if len(channelRows) == 0 {
                        // There are no channels to prune. So we can exit early
                        // after updating the prune log.
                        err = db.UpsertPruneLogEntry(
                                ctx, sqlc.UpsertPruneLogEntryParams{
                                        BlockHash:   blockHash[:],
                                        BlockHeight: int64(blockHeight),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert prune log "+
                                        "entry: %w", err)
                        }

                        return nil
                }

                // Batch build all channel edges for pruning.
                var chansToDelete []int64
                closedChans, chansToDelete, err = batchBuildChannelInfo(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return err
                }

                err = s.deleteChannels(ctx, db, chansToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                err = db.UpsertPruneLogEntry(
                        ctx, sqlc.UpsertPruneLogEntryParams{
                                BlockHash:   blockHash[:],
                                BlockHeight: int64(blockHeight),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert prune log "+
                                "entry: %w", err)
                }

                // Now that we've pruned some channels, we'll also prune any
                // nodes that no longer have any channels.
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
                if err != nil {
                        return fmt.Errorf("unable to prune graph nodes: %w",
                                err)
                }

                return nil
        }, func() {
                prunedNodes = nil
                closedChans = nil
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
        }

        for _, channel := range closedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return closedChans, prunedNodes, nil
}
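
// Illustrative usage (a minimal sketch, not part of the original file): when
// a new block is connected, its spent funding outpoints can be handed to
// PruneGraph so that closed channels and newly unconnected nodes are removed
// in one transaction. The helper name is an assumption for this example only.
func examplePruneOnBlock(store *SQLStore, spent []*wire.OutPoint,
        hash *chainhash.Hash, height uint32) (int, int, error) {

        closed, pruned, err := store.PruneGraph(spent, hash, height)
        if err != nil {
                return 0, 0, err
        }

        // Report how many channels the block closed and how many nodes were
        // left without channels as a result.
        return len(closed), len(pruned), nil
}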

// forEachChanInOutpoints is a helper function that executes a paginated
// query to fetch channels by their outpoints and applies the given call-back
// to each.
//
// NOTE: this fetches channels for all protocol versions.
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
                row sqlc.GetChannelsByOutpointsRow) error) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context,
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
                error) {

                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
        }

        // Define the conversion function from Outpoint to string.
        outpointToString := func(outpoint *wire.OutPoint) string {
                return outpoint.String()
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
                queryWrapper, cb,
        )
}

func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
        dbIDs []int64) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
                return nil, db.DeleteChannels(ctx, ids)
        }

        idConverter := func(id int64) int64 {
                return id
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
                queryWrapper, func(ctx context.Context, _ any) error {
                        return nil
                },
        )
}

// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
        var (
                ctx        = context.TODO()
                edgePoints []EdgePoint
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                handleChannel := func(_ context.Context,
                        channel sqlc.ListChannelsPaginatedRow) error {

                        pkScript, err := genMultiSigP2WSH(
                                channel.BitcoinKey1, channel.BitcoinKey2,
                        )
                        if err != nil {
                                return err
                        }

                        op, err := wire.NewOutPointFromString(channel.Outpoint)
                        if err != nil {
                                return err
                        }

                        edgePoints = append(edgePoints, EdgePoint{
                                FundingPkScript: pkScript,
                                OutPoint:        *op,
                        })

                        return nil
                }

                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {

                        return db.ListChannelsPaginated(
                                ctx, sqlc.ListChannelsPaginatedParams{
                                        Version: int16(lnwire.GossipVersion1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
                        return row.ID
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, func() {
                edgePoints = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
        }

        return edgePoints, nil
}
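
// Illustrative usage (a minimal sketch, not part of the original file): the
// edge points returned by ChannelView are typically registered with a chain
// watcher; this sketch only extracts the funding outpoints that would need to
// be monitored for spends. The helper name is an assumption for this example
// only.
func exampleOutpointsToWatch(store *SQLStore) ([]wire.OutPoint, error) {
        edgePoints, err := store.ChannelView()
        if err != nil {
                return nil, err
        }

        ops := make([]wire.OutPoint, 0, len(edgePoints))
        for _, ep := range edgePoints {
                ops = append(ops, ep.OutPoint)
        }

        return ops, nil
}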

// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
        var (
                ctx       = context.TODO()
                tipHash   chainhash.Hash
                tipHeight uint32
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                pruneTip, err := db.GetPruneTip(ctx)
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrGraphNeverPruned
                } else if err != nil {
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
                }

                tipHash = chainhash.Hash(pruneTip.BlockHash)
                tipHeight = uint32(pruneTip.BlockHeight)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, 0, err
        }

        return &tipHash, tipHeight, nil
}

// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
        db SQLQueries) ([]route.Vertex, error) {

        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
        if err != nil {
                return nil, fmt.Errorf("unable to delete unconnected "+
                        "nodes: %w", err)
        }

        prunedNodes := make([]route.Vertex, len(nodeKeys))
        for i, nodeKey := range nodeKeys {
                pub, err := route.NewVertexFromBytes(nodeKey)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse pubkey "+
                                "from bytes: %w", err)
                }

                prunedNodes[i] = pub
        }

        return prunedNodes, nil
}

// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
        []*models.ChannelEdgeInfo, error) {

        ctx := context.TODO()

        var (
                // Every channel having a ShortChannelID starting at 'height'
                // will no longer be confirmed.
                startShortChanID = lnwire.ShortChannelID{
                        BlockHeight: height,
                }

                // Delete everything after this height from the db up until the
                // SCID alias range.
                endShortChanID = aliasmgr.StartingAlias

                removedChans []*models.ChannelEdgeInfo

                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsBySCIDRange(
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
                                StartScid: chanIDStart,
                                EndScid:   chanIDEnd,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels: %w", err)
                }

                if len(rows) == 0 {
                        // No channels to disconnect, but still clean up prune
                        // log.
                        return db.DeletePruneLogEntriesInRange(
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                        StartHeight: int64(height),
                                        EndHeight: int64(
                                                endShortChanID.BlockHeight,
                                        ),
                                },
                        )
                }

                // Batch build all channel edges for disconnection.
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
                        ctx, s.cfg, db, rows,
                )
                if err != nil {
                        return err
                }

                removedChans = channelEdges

                err = s.deleteChannels(ctx, db, chanIDsToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                return db.DeletePruneLogEntriesInRange(
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                StartHeight: int64(height),
                                EndHeight:   int64(endShortChanID.BlockHeight),
                        },
                )
        }, func() {
                removedChans = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to disconnect block at "+
                        "height: %w", err)
        }

        s.cacheMu.Lock()
        for _, channel := range removedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }
        s.cacheMu.Unlock()

        return removedChans, nil
}
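
// Illustrative usage (a minimal sketch, not part of the original file): on a
// chain reorganization, passing the height of the first disconnected block to
// DisconnectBlockAtHeight rewinds the graph so that no channel confirmed at or
// above that height remains. The helper name is an assumption for this example
// only.
func exampleRewindToForkPoint(store *SQLStore,
        firstDisconnectedHeight uint32) (int, error) {

        removed, err := store.DisconnectBlockAtHeight(firstDisconnectedHeight)
        if err != nil {
                return 0, err
        }

        // The removed channels would typically be re-added once the new chain
        // confirms them again.
        return len(removed), nil
}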

// AddEdgeProof sets the proof of an existing edge in the graph database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
        proof *models.ChannelAuthProof) error {

        var (
                ctx       = context.TODO()
                scidBytes = channelIDToBytes(scid.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.AddV1ChannelProof(
                        ctx, sqlc.AddV1ChannelProofParams{
                                Scid:              scidBytes,
                                Node1Signature:    proof.NodeSig1Bytes,
                                Node2Signature:    proof.NodeSig2Bytes,
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to add edge proof: %w", err)
                }

                n, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if n == 0 {
                        return fmt.Errorf("no rows affected when adding edge "+
                                "proof for SCID %v", scid)
                } else if n > 1 {
                        return fmt.Errorf("multiple rows affected when adding "+
                                "edge proof for SCID %v: %d rows affected",
                                scid, n)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to add edge proof: %w", err)
        }

        return nil
}

// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
        var (
                ctx     = context.TODO()
                chanIDB = channelIDToBytes(scid.ToUint64())
        )

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                return db.InsertClosedChannel(ctx, chanIDB)
        }, sqldb.NoOpReset)
}

// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
        var (
                ctx      = context.TODO()
                isClosed bool
                chanIDB  = channelIDToBytes(scid.ToUint64())
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("unable to fetch closed channel: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to fetch closed channel: %w",
                        err)
        }

        return isClosed, nil
}
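
// Illustrative usage (a minimal sketch, not part of the original file): a
// gossip handler can consult IsClosedScid before doing any expensive
// validation of a channel announcement, since a known-closed SCID can be
// ignored without fetching the funding block. The helper name is an assumption
// for this example only.
func exampleShouldProcessAnnouncement(store *SQLStore,
        scid lnwire.ShortChannelID) (bool, error) {

        closed, err := store.IsClosedScid(scid)
        if err != nil {
                return false, err
        }

        return !closed, nil
}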

// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
        reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
        }, reset)
}

// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
        db    SQLQueries
        chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)

// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
func newSQLNodeTraverser(db SQLQueries,
        chain chainhash.Hash) *sqlNodeTraverser {

        return &sqlNodeTraverser{
                db:    db,
                chain: chain,
        }
}

// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, _ func()) error {

        ctx := context.TODO()

        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
}

// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}
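
// Illustrative usage (a minimal sketch, not part of the original file):
// GraphSession lets a caller run several traversal queries against a single
// read transaction so they all observe the same snapshot of the graph. The
// helper name and the counting logic are assumptions for this example only.
func exampleCountDirectedChannels(store *SQLStore,
        nodePub route.Vertex) (int, error) {

        var numChans int
        err := store.GraphSession(func(graph NodeTraverser) error {
                return graph.ForEachNodeDirectedChannel(
                        nodePub, func(_ *DirectedChannel) error {
                                numChans++

                                return nil
                        }, func() {
                                numChans = 0
                        },
                )
        }, func() {
                // Reset the count if the underlying transaction is retried.
                numChans = 0
        })

        return numChans, err
}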

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {

        toNodeCallback := func() route.Vertex {
                return nodePub
        }

        dbID, err := db.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        Version: int16(lnwire.GossipVersion1),
                        PubKey:  nodePub[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return nil
        } else if err != nil {
                return fmt.Errorf("unable to fetch node: %w", err)
        }

        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(lnwire.GossipVersion1),
                        NodeID1: dbID,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Exit early if there are no channels for this node so we don't
        // do the unnecessary feature fetching.
        if len(rows) == 0 {
                return nil
        }

        features, err := getNodeFeatures(ctx, db, dbID)
        if err != nil {
                return fmt.Errorf("unable to fetch node features: %w", err)
        }

        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge := buildCacheableChannelInfo(
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
                        node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                outPolicy, inPolicy := p1, p2
                if p1 != nil && node2 == nodePub {
                        outPolicy, inPolicy = p2, p1
                } else if p2 != nil && node1 != nodePub {
                        outPolicy, inPolicy = p2, p1
                }

                var cachedInPolicy *models.CachedEdgePolicy
                if inPolicy != nil {
                        cachedInPolicy = inPolicy
                        cachedInPolicy.ToNodePubKey = toNodeCallback
                        cachedInPolicy.ToNodeFeatures = features
                }

                directedChannel := &DirectedChannel{
                        ChannelID:    edge.ChannelID,
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
                        OtherNode:    edge.NodeKey2Bytes,
                        Capacity:     edge.Capacity,
                        OutPolicySet: outPolicy != nil,
                        InPolicy:     cachedInPolicy,
                }
                if outPolicy != nil {
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                                directedChannel.InboundFee = fee
                        })
                }

                if nodePub == edge.NodeKey2Bytes {
                        directedChannel.OtherNode = edge.NodeKey1Bytes
                }

                if err := cb(directedChannel); err != nil {
                        return err
                }
        }

        return nil
}

// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node. It does so via pagination
// along with batch loading of the node feature bits.
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
                features *lnwire.FeatureVector) error) error {

        handleNode := func(_ context.Context,
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
                featureBits map[int64][]int) error {

                fv := lnwire.EmptyFeatureVector()
                if features, exists := featureBits[dbNode.ID]; exists {
                        for _, bit := range features {
                                fv.Set(lnwire.FeatureBit(bit))
                        }
                }

                var pub route.Vertex
                copy(pub[:], dbNode.PubKey)

                return processNode(dbNode.ID, pub, fv)
        }

        queryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                return db.ListNodeIDsAndPubKeys(
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                Version: int16(lnwire.GossipVersion1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
                return row.ID
        }

        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (map[int64][]int, error) {

                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
                batchQueryFunc, handleNode,
        )
}
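
// Illustrative usage (a minimal sketch, not part of the original file):
// forEachNodeCacheable can be used to warm an in-memory map of node features
// keyed by pubkey. It is expected to run against a SQLQueries instance, for
// example inside an ExecTx callback. The helper name is an assumption for this
// example only.
func exampleCollectNodeFeatures(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries) (map[route.Vertex]*lnwire.FeatureVector, error) {

        features := make(map[route.Vertex]*lnwire.FeatureVector)
        err := forEachNodeCacheable(ctx, cfg, db,
                func(_ int64, nodePub route.Vertex,
                        fv *lnwire.FeatureVector) error {

                        features[nodePub] = fv

                        return nil
                },
        )
        if err != nil {
                return nil, err
        }

        return features, nil
}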

// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        // Get all the V1 channels for this node.
        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(lnwire.GossipVersion1),
                        NodeID1: id,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Collect all the channel and policy IDs.
        var (
                chanIDs   = make([]int64, 0, len(rows))
                policyIDs = make([]int64, 0, 2*len(rows))
        )
        for _, row := range rows {
                chanIDs = append(chanIDs, row.GraphChannel.ID)

                if row.Policy1ID.Valid {
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
                }
                if row.Policy2ID.Valid {
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
                }
        }

        batchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
        )
        if err != nil {
                return fmt.Errorf("unable to batch load channel data: %w", err)
        }

        // Call the call-back for each channel and its known policies.
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                p1ToNode := row.GraphChannel.NodeID2
                p2ToNode := row.GraphChannel.NodeID1
                outPolicy, inPolicy := p1, p2
                if (p1 != nil && p1ToNode == id) ||
                        (p2 != nil && p2ToNode != id) {

                        outPolicy, inPolicy = p2, p1
                }

                if err := cb(edge, outPolicy, inPolicy); err != nil {
                        return err
                }
        }

        return nil
}

// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
        error) {

        var (
                node1Pub, node2Pub route.Vertex
                isNode1            bool
                chanIDB            = channelIDToBytes(edge.ChannelID)
        )

        // Check that this edge policy refers to a channel that we already
        // know of. We do this explicitly so that we can return the appropriate
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
        // abort the transaction which would abort the entire batch.
        dbChan, err := tx.GetChannelAndNodesBySCID(
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
                        Scid:    chanIDB,
                        Version: int16(lnwire.GossipVersion1),
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return node1Pub, node2Pub, false, ErrEdgeNotFound
        } else if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "fetch channel(%v): %w", edge.ChannelID, err)
        }

        copy(node1Pub[:], dbChan.Node1PubKey)
        copy(node2Pub[:], dbChan.Node2PubKey)

        // Figure out which node this edge is from.
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.NodeID1
        if !isNode1 {
                nodeID = dbChan.NodeID2
        }

        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
                Version:     int16(lnwire.GossipVersion1),
                ChannelID:   dbChan.ID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return node1Pub, node2Pub, isNode1,
                        fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "marshal extra opaque data: %w", err)
        }

        // Update the channel policy's extra signed fields.
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
                        "policy extra TLVs: %w", err)
        }

        return node1Pub, node2Pub, isNode1, nil
}
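
// Illustrative sketch (not part of the original file): the direction of a
// policy update is derived from the ChanUpdateDirection bit of ChannelFlags,
// mirroring the check performed in updateChanEdgePolicy above. The helper name
// is an assumption for this example only.
func exampleIsFromNode1(edge *models.ChannelEdgePolicy) bool {
        // A cleared direction bit means the update was issued by node 1 of the
        // channel; a set bit means node 2.
        return edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
}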

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        pubKey route.Vertex) (int64, *models.Node, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        Version: int16(lnwire.GossipVersion1),
                        PubKey:  pubKey[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return 0, nil, ErrGraphNodeNotFound
        } else if err != nil {
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        node, err := buildNode(ctx, cfg, db, dbNode)
        if err != nil {
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
        }

        return dbNode.ID, node, nil
}

// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
// provided parameters.
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
        node2Pub route.Vertex) *models.CachedEdgeInfo {

        return &models.CachedEdgeInfo{
                ChannelID:     byteOrder.Uint64(scid),
                NodeKey1Bytes: node1Pub,
                NodeKey2Bytes: node2Pub,
                Capacity:      btcutil.Amount(capacity),
        }
}

// buildNode constructs a Node instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        dbNode sqlc.GraphNode) (*models.Node, error) {

        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        return buildNodeWithBatchData(dbNode, data)
}

// buildNodeWithBatchData builds a models.Node instance
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
// features/addresses/extra fields, then the corresponding fields are expected
// to be present in the batchNodeData.
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
        batchData *batchNodeData) (*models.Node, error) {

        if dbNode.Version != int16(lnwire.GossipVersion1) {
                return nil, fmt.Errorf("unsupported node version: %d",
                        dbNode.Version)
        }

        var pub [33]byte
        copy(pub[:], dbNode.PubKey)

        node := models.NewV1ShellNode(pub)

        if len(dbNode.Signature) == 0 {
                return node, nil
        }

        node.AuthSigBytes = dbNode.Signature

        if dbNode.Alias.Valid {
                node.Alias = fn.Some(dbNode.Alias.String)
        }
        if dbNode.LastUpdate.Valid {
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
        }

        var err error
        if dbNode.Color.Valid {
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode color: %w",
                                err)
                }

                node.Color = fn.Some(nodeColor)
        }

        // Use preloaded features.
        if features, exists := batchData.features[dbNode.ID]; exists {
                fv := lnwire.EmptyFeatureVector()
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
                node.Features = fv
        }

        // Use preloaded addresses.
        addresses, exists := batchData.addresses[dbNode.ID]
        if exists && len(addresses) > 0 {
                node.Addresses, err = buildNodeAddresses(addresses)
                if err != nil {
                        return nil, fmt.Errorf("unable to build addresses "+
                                "for node(%d): %w", dbNode.ID, err)
                }
        }

        // Use preloaded extra fields.
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
                if err != nil {
                        return nil, fmt.Errorf("unable to serialize extra "+
                                "signed fields: %w", err)
                }
                if len(recs) != 0 {
                        node.ExtraOpaqueData = recs
                }
        }

        return node, nil
}

// forEachNodeInBatch fetches all nodes in the provided batch, builds them
// with the preloaded data, and executes the provided callback for each node.
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodes []sqlc.GraphNode,
        cb func(dbID int64, node *models.Node) error) error {

        // Extract node IDs for batch loading.
        nodeIDs := make([]int64, len(nodes))
        for i, node := range nodes {
                nodeIDs[i] = node.ID
        }

        // Batch load all related data for this page.
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
        if err != nil {
                return fmt.Errorf("unable to batch load node data: %w", err)
        }

        for _, dbNode := range nodes {
                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build node(id=%d): %w",
                                dbNode.ID, err)
                }

                if err := cb(dbNode.ID, node); err != nil {
                        return fmt.Errorf("callback failed for node(id=%d): %w",
                                dbNode.ID, err)
                }
        }

        return nil
}

// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
        nodeID int64) (*lnwire.FeatureVector, error) {

        rows, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
                        nodeID, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, feature := range rows {
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
        }

        return features, nil
}
3619

3620
// upsertNodeAncillaryData updates the node's features, addresses, and extra
3621
// signed fields. This is common logic shared by upsertNode and
3622
// upsertSourceNode.
3623
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
3624
        nodeID int64, node *models.Node) error {
×
3625

×
3626
        // Update the node's features.
×
3627
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3628
        if err != nil {
×
3629
                return fmt.Errorf("inserting node features: %w", err)
×
3630
        }
×
3631

3632
        // Update the node's addresses.
3633
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3634
        if err != nil {
×
3635
                return fmt.Errorf("inserting node addresses: %w", err)
×
3636
        }
×
3637

3638
        // Convert the flat extra opaque data into a map of TLV types to
3639
        // values.
3640
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3641
        if err != nil {
×
3642
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3643
                        err)
×
3644
        }
×
3645

3646
        // Update the node's extra signed fields.
3647
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3648
        if err != nil {
×
3649
                return fmt.Errorf("inserting node extra TLVs: %w", err)
×
3650
        }
×
3651

3652
        return nil
×
3653
}
3654

3655
// populateNodeParams populates the common node parameters from a models.Node.
3656
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
3657
func populateNodeParams(node *models.Node,
3658
        setParams func(lastUpdate sql.NullInt64, alias,
3659
                colorStr sql.NullString, signature []byte)) error {
×
3660

×
3661
        if !node.HaveAnnouncement() {
×
3662
                return nil
×
3663
        }
×
3664

3665
        switch node.Version {
×
3666
        case lnwire.GossipVersion1:
×
3667
                lastUpdate := sqldb.SQLInt64(node.LastUpdate.Unix())
×
3668
                var alias, colorStr sql.NullString
×
3669

×
3670
                node.Color.WhenSome(func(rgba color.RGBA) {
×
3671
                        colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
3672
                })
×
3673
                node.Alias.WhenSome(func(s string) {
×
3674
                        alias = sqldb.SQLStrValid(s)
×
3675
                })
×
3676

3677
                setParams(lastUpdate, alias, colorStr, node.AuthSigBytes)
×
3678

3679
        case lnwire.GossipVersion2:
×
3680
                // No-op for now.
3681

3682
        default:
×
3683
                return fmt.Errorf("unknown gossip version: %d", node.Version)
×
3684
        }
3685

3686
        return nil
×
3687
}
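
populateNodeParams keeps the version-specific field handling in one place and injects the results through a setter closure, which lets the same logic fill both UpsertNodeParams and UpsertSourceNodeParams. A minimal standalone sketch of that closure-setter pattern follows, using two hypothetical parameter structs rather than the sqlc-generated ones.

package main

import "fmt"

// insertParams and updateParams are two distinct (hypothetical) parameter
// structs that happen to need the same derived values.
type insertParams struct{ Alias string }
type updateParams struct{ Alias string }

// populate holds the shared derivation logic once and hands the results to a
// caller-supplied setter, so it never needs to know which struct it fills.
func populate(rawAlias string, set func(alias string)) {
	// Imagine validation or version-specific handling happening here.
	set(rawAlias)
}

func main() {
	var ins insertParams
	populate("carol", func(a string) { ins.Alias = a })

	var upd updateParams
	populate("carol", func(a string) { upd.Alias = a })

	fmt.Println(ins.Alias, upd.Alias)
}
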
3688

3689
// buildNodeUpsertParams builds the parameters for upserting a node using the
3690
// strict UpsertNode query (requires timestamp to be increasing).
3691
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
×
3692
        params := sqlc.UpsertNodeParams{
×
3693
                Version: int16(lnwire.GossipVersion1),
×
3694
                PubKey:  node.PubKeyBytes[:],
×
3695
        }
×
3696

×
3697
        err := populateNodeParams(
×
3698
                node, func(lastUpdate sql.NullInt64, alias,
×
3699
                        colorStr sql.NullString,
×
3700
                        signature []byte) {
×
3701

×
3702
                        params.LastUpdate = lastUpdate
×
3703
                        params.Alias = alias
×
3704
                        params.Color = colorStr
×
3705
                        params.Signature = signature
×
3706
                })
×
3707

3708
        return params, err
×
3709
}
3710

3711
// buildSourceNodeUpsertParams builds the parameters for upserting the source
3712
// node using the lenient UpsertSourceNode query (allows same timestamp).
3713
func buildSourceNodeUpsertParams(node *models.Node) (
3714
        sqlc.UpsertSourceNodeParams, error) {
×
3715

×
3716
        params := sqlc.UpsertSourceNodeParams{
×
3717
                Version: int16(lnwire.GossipVersion1),
×
3718
                PubKey:  node.PubKeyBytes[:],
×
3719
        }
×
3720

×
3721
        err := populateNodeParams(
×
3722
                node, func(lastUpdate sql.NullInt64, alias,
×
3723
                        colorStr sql.NullString, signature []byte) {
×
3724

×
3725
                        params.LastUpdate = lastUpdate
×
3726
                        params.Alias = alias
×
3727
                        params.Color = colorStr
×
3728
                        params.Signature = signature
×
3729
                },
×
3730
        )
3731

3732
        return params, err
×
3733
}
3734

3735
// upsertSourceNode upserts the source node record into the database using a
3736
// less strict upsert that allows updates even when the timestamp hasn't
3737
// changed. This is necessary to handle concurrent updates to our own node
3738
// during startup and runtime. The node's features, addresses and extra TLV
3739
// types are also updated. The node's DB ID is returned.
3740
func upsertSourceNode(ctx context.Context, db SQLQueries,
3741
        node *models.Node) (int64, error) {
×
3742

×
3743
        params, err := buildSourceNodeUpsertParams(node)
×
3744
        if err != nil {
×
3745
                return 0, err
×
3746
        }
×
3747

3748
        nodeID, err := db.UpsertSourceNode(ctx, params)
×
3749
        if err != nil {
×
3750
                return 0, fmt.Errorf("upserting source node(%x): %w",
×
3751
                        node.PubKeyBytes, err)
×
3752
        }
×
3753

3754
        // We can exit here if we don't have the announcement yet.
3755
        if !node.HaveAnnouncement() {
×
3756
                return nodeID, nil
×
3757
        }
×
3758

3759
        // Update the ancillary node data (features, addresses, extra fields).
3760
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3761
        if err != nil {
×
3762
                return 0, err
×
3763
        }
×
3764

3765
        return nodeID, nil
×
3766
}
3767

3768
// upsertNode upserts the node record into the database. If the node already
3769
// exists, then the node's information is updated. If the node doesn't exist,
3770
// then a new node is created. The node's features, addresses and extra TLV
3771
// types are also updated. The node's DB ID is returned.
3772
func upsertNode(ctx context.Context, db SQLQueries,
3773
        node *models.Node) (int64, error) {
×
3774

×
3775
        params, err := buildNodeUpsertParams(node)
×
3776
        if err != nil {
×
3777
                return 0, err
×
3778
        }
×
3779

3780
        nodeID, err := db.UpsertNode(ctx, params)
×
3781
        if err != nil {
×
3782
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3783
                        err)
×
3784
        }
×
3785

3786
        // We can exit here if we don't have the announcement yet.
3787
        if !node.HaveAnnouncement() {
×
3788
                return nodeID, nil
×
3789
        }
×
3790

3791
        // Update the ancillary node data (features, addresses, extra fields).
3792
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
×
3793
        if err != nil {
×
3794
                return 0, err
×
3795
        }
×
3796

3797
        return nodeID, nil
×
3798
}
3799

3800
// upsertNodeFeatures updates the node's features in the node_features table. This
3801
// includes deleting any feature bits no longer present and inserting any new
3802
// feature bits. If the feature bit does not yet exist in the features table,
3803
// then an entry is created in that table first.
3804
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3805
        features *lnwire.FeatureVector) error {
×
3806

×
3807
        // Get any existing features for the node.
×
3808
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3809
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3810
                return err
×
3811
        }
×
3812

3813
        // Copy the node's latest set of feature bits.
3814
        newFeatures := make(map[int32]struct{})
×
3815
        if features != nil {
×
3816
                for feature := range features.Features() {
×
3817
                        newFeatures[int32(feature)] = struct{}{}
×
3818
                }
×
3819
        }
3820

3821
        // For any current feature that already exists in the DB, remove it from
3822
        // the in-memory map. For any existing feature that does not exist in
3823
        // the in-memory map, delete it from the database.
3824
        for _, feature := range existingFeatures {
×
3825
                // The feature is still present, so there are no updates to be
×
3826
                // made.
×
3827
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3828
                        delete(newFeatures, feature.FeatureBit)
×
3829
                        continue
×
3830
                }
3831

3832
                // The feature is no longer present, so we remove it from the
3833
                // database.
3834
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3835
                        NodeID:     nodeID,
×
3836
                        FeatureBit: feature.FeatureBit,
×
3837
                })
×
3838
                if err != nil {
×
3839
                        return fmt.Errorf("unable to delete node(%d) "+
×
3840
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3841
                                err)
×
3842
                }
×
3843
        }
3844

3845
        // Any remaining entries in newFeatures are new features that need to be
3846
        // added to the database for the first time.
3847
        for feature := range newFeatures {
×
3848
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3849
                        NodeID:     nodeID,
×
3850
                        FeatureBit: feature,
×
3851
                })
×
3852
                if err != nil {
×
3853
                        return fmt.Errorf("unable to insert node(%d) "+
×
3854
                                "feature(%v): %w", nodeID, feature, err)
×
3855
                }
×
3856
        }
3857

3858
        return nil
×
3859
}
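
The loop structure above is a set difference between the feature bits already stored for the node and the bits in the latest announcement: bits present on both sides are untouched, stale bits are deleted, and the leftovers are inserted. A rough standalone illustration of that diff-based sync, with hypothetical deleteBit/insertBit callbacks standing in for the DB queries:

package main

import "fmt"

// syncBits deletes stored bits that are no longer wanted and inserts wanted
// bits that are not yet stored, mirroring the upsert-by-diff approach.
func syncBits(stored, wanted []int32,
	deleteBit, insertBit func(bit int32)) {

	want := make(map[int32]struct{}, len(wanted))
	for _, b := range wanted {
		want[b] = struct{}{}
	}

	for _, b := range stored {
		if _, ok := want[b]; ok {
			// Present on both sides: nothing to do.
			delete(want, b)
			continue
		}
		deleteBit(b)
	}

	// Whatever is left in the map is new.
	for b := range want {
		insertBit(b)
	}
}

func main() {
	syncBits(
		[]int32{0, 5, 7}, []int32{0, 7, 9},
		func(b int32) { fmt.Println("delete", b) },
		func(b int32) { fmt.Println("insert", b) },
	)
}
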
3860

3861
// fetchNodeFeatures fetches the features for a node with the given public key.
3862
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3863
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3864

×
3865
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3866
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3867
                        PubKey:  nodePub[:],
×
3868
                        Version: int16(lnwire.GossipVersion1),
×
3869
                },
×
3870
        )
×
3871
        if err != nil {
×
3872
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3873
                        nodePub, err)
×
3874
        }
×
3875

3876
        features := lnwire.EmptyFeatureVector()
×
3877
        for _, bit := range rows {
×
3878
                features.Set(lnwire.FeatureBit(bit))
×
3879
        }
×
3880

3881
        return features, nil
×
3882
}
3883

3884
// dbAddressType is an enum type that represents the different address types
3885
// that we store in the node_addresses table. The address type determines how
3886
// the address is to be serialised/deserialised.
3887
type dbAddressType uint8
3888

3889
const (
3890
        addressTypeIPv4   dbAddressType = 1
3891
        addressTypeIPv6   dbAddressType = 2
3892
        addressTypeTorV2  dbAddressType = 3
3893
        addressTypeTorV3  dbAddressType = 4
3894
        addressTypeDNS    dbAddressType = 5
3895
        addressTypeOpaque dbAddressType = math.MaxInt8
3896
)
3897

3898
// collectAddressRecords collects the addresses from the provided
3899
// net.Addr slice and returns a map of dbAddressType to a slice of address
3900
// strings.
3901
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3902
        error) {
×
3903

×
3904
        // Copy the nodes latest set of addresses.
×
3905
        newAddresses := map[dbAddressType][]string{
×
3906
                addressTypeIPv4:   {},
×
3907
                addressTypeIPv6:   {},
×
3908
                addressTypeTorV2:  {},
×
3909
                addressTypeTorV3:  {},
×
3910
                addressTypeDNS:    {},
×
3911
                addressTypeOpaque: {},
×
3912
        }
×
3913
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3914
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3915
        }
×
3916

3917
        for _, address := range addresses {
×
3918
                switch addr := address.(type) {
×
3919
                case *net.TCPAddr:
×
3920
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3921
                                addAddr(addressTypeIPv4, addr)
×
3922
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3923
                                addAddr(addressTypeIPv6, addr)
×
3924
                        } else {
×
3925
                                return nil, fmt.Errorf("unhandled IP "+
×
3926
                                        "address: %v", addr)
×
3927
                        }
×
3928

3929
                case *tor.OnionAddr:
×
3930
                        switch len(addr.OnionService) {
×
3931
                        case tor.V2Len:
×
3932
                                addAddr(addressTypeTorV2, addr)
×
3933
                        case tor.V3Len:
×
3934
                                addAddr(addressTypeTorV3, addr)
×
3935
                        default:
×
3936
                                return nil, fmt.Errorf("invalid length for " +
×
3937
                                        "a tor address")
×
3938
                        }
3939

3940
                case *lnwire.DNSAddress:
×
3941
                        addAddr(addressTypeDNS, addr)
×
3942

3943
                case *lnwire.OpaqueAddrs:
×
3944
                        addAddr(addressTypeOpaque, addr)
×
3945

3946
                default:
×
3947
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3948
                                addr)
×
3949
                }
3950
        }
3951

3952
        return newAddresses, nil
×
3953
}
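
The switch above buckets each net.Addr into a per-type slice while keeping the original order inside each bucket. A simplified standalone version that only distinguishes IPv4 from IPv6 TCP addresses is sketched below (the real code also handles Tor, DNS and opaque addresses):

package main

import (
	"fmt"
	"net"
)

// groupTCPAddrs splits TCP addresses into IPv4 and IPv6 buckets, keeping the
// original order within each bucket.
func groupTCPAddrs(addrs []*net.TCPAddr) (v4, v6 []string, err error) {
	for _, a := range addrs {
		switch {
		case a.IP.To4() != nil:
			v4 = append(v4, a.String())
		case a.IP.To16() != nil:
			v6 = append(v6, a.String())
		default:
			return nil, nil, fmt.Errorf("unhandled IP: %v", a)
		}
	}

	return v4, v6, nil
}

func main() {
	addrs := []*net.TCPAddr{
		{IP: net.ParseIP("1.2.3.4"), Port: 9735},
		{IP: net.ParseIP("2001:db8::1"), Port: 9735},
	}
	v4, v6, _ := groupTCPAddrs(addrs)
	fmt.Println(v4, v6)
}
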
3954

3955
// upsertNodeAddresses updates the node's addresses in the database. This
3956
// includes deleting any existing addresses and inserting the new set of
3957
// addresses. The deletion is necessary since the ordering of the addresses may
3958
// change, and we need to ensure that the database reflects the latest set of
3959
// addresses so that at the time of reconstructing the node announcement, the
3960
// order is preserved and the signature over the message remains valid.
3961
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3962
        addresses []net.Addr) error {
×
3963

×
3964
        // Delete any existing addresses for the node. This is required since
×
3965
        // even if the new set of addresses is the same, the ordering may have
×
3966
        // changed for a given address type.
×
3967
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3968
        if err != nil {
×
3969
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3970
                        nodeID, err)
×
3971
        }
×
3972

3973
        newAddresses, err := collectAddressRecords(addresses)
×
3974
        if err != nil {
×
3975
                return err
×
3976
        }
×
3977

3978
        // Since any existing addresses were deleted above, insert the full
3979
        // new set of addresses, preserving the order within each type.
3980
        for addrType, addrList := range newAddresses {
×
3981
                for position, addr := range addrList {
×
3982
                        err := db.UpsertNodeAddress(
×
3983
                                ctx, sqlc.UpsertNodeAddressParams{
×
3984
                                        NodeID:   nodeID,
×
3985
                                        Type:     int16(addrType),
×
3986
                                        Address:  addr,
×
3987
                                        Position: int32(position),
×
3988
                                },
×
3989
                        )
×
3990
                        if err != nil {
×
3991
                                return fmt.Errorf("unable to insert "+
×
3992
                                        "node(%d) address(%v): %w", nodeID,
×
3993
                                        addr, err)
×
3994
                        }
×
3995
                }
3996
        }
3997

3998
        return nil
×
3999
}
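
Because every address is written with an explicit position within its type, a later read can ORDER BY that column and reproduce the announcement's exact ordering. The small sketch below shows just the flattening step that assigns those positions; the addrRow type and the numeric bucket key are illustrative stand-ins, not the actual schema.

package main

import "fmt"

// addrRow mirrors the (type, position, address) triple written per address so
// that the original ordering can be recovered by sorting on position.
type addrRow struct {
	Type     int16
	Position int32
	Address  string
}

// flattenAddresses assigns each address its index within its type bucket.
func flattenAddresses(byType map[int16][]string) []addrRow {
	var rows []addrRow
	for t, list := range byType {
		for i, addr := range list {
			rows = append(rows, addrRow{
				Type:     t,
				Position: int32(i),
				Address:  addr,
			})
		}
	}

	return rows
}

func main() {
	rows := flattenAddresses(map[int16][]string{
		1: {"1.2.3.4:9735", "5.6.7.8:9735"}, // e.g. an IPv4 bucket
	})
	fmt.Println(rows)
}
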
4000

4001
// getNodeAddresses fetches the addresses for a node with the given DB ID.
4002
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
4003
        error) {
×
4004

×
4005
        // GetNodeAddresses ensures that the addresses for a given type are
×
4006
        // returned in the same order as they were inserted.
×
4007
        rows, err := db.GetNodeAddresses(ctx, id)
×
4008
        if err != nil {
×
4009
                return nil, err
×
4010
        }
×
4011

4012
        addresses := make([]net.Addr, 0, len(rows))
×
4013
        for _, row := range rows {
×
4014
                address := row.Address
×
4015

×
4016
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
4017
                if err != nil {
×
4018
                        return nil, fmt.Errorf("unable to parse address "+
×
4019
                                "for node(%d): %v: %w", id, address, err)
×
4020
                }
×
4021

4022
                addresses = append(addresses, addr)
×
4023
        }
4024

4025
        // If we have no addresses, then we'll return nil instead of an
4026
        // empty slice.
4027
        if len(addresses) == 0 {
×
4028
                addresses = nil
×
4029
        }
×
4030

4031
        return addresses, nil
×
4032
}
4033

4034
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
4035
// database. This includes updating any existing types, inserting any new types,
4036
// and deleting any types that are no longer present.
4037
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
4038
        nodeID int64, extraFields map[uint64][]byte) error {
×
4039

×
4040
        // Get any existing extra signed fields for the node.
×
4041
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
4042
        if err != nil {
×
4043
                return err
×
4044
        }
×
4045

4046
        // Make a lookup map of the existing field types so that we can use it
4047
        // to keep track of any fields we should delete.
4048
        m := make(map[uint64]bool)
×
4049
        for _, field := range existingFields {
×
4050
                m[uint64(field.Type)] = true
×
4051
        }
×
4052

4053
        // For all the new fields, we'll upsert them and remove them from the
4054
        // map of existing fields.
4055
        for tlvType, value := range extraFields {
×
4056
                err = db.UpsertNodeExtraType(
×
4057
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
4058
                                NodeID: nodeID,
×
4059
                                Type:   int64(tlvType),
×
4060
                                Value:  value,
×
4061
                        },
×
4062
                )
×
4063
                if err != nil {
×
4064
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
4065
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4066
                }
×
4067

4068
                // Remove the field from the map of existing fields if it was
4069
                // present.
4070
                delete(m, tlvType)
×
4071
        }
4072

4073
        // For all the fields that are left in the map of existing fields, we'll
4074
        // delete them as they are no longer present in the new set of fields.
4075
        for tlvType := range m {
×
4076
                err = db.DeleteExtraNodeType(
×
4077
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
4078
                                NodeID: nodeID,
×
4079
                                Type:   int64(tlvType),
×
4080
                        },
×
4081
                )
×
4082
                if err != nil {
×
4083
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
4084
                                "signed field(%v): %w", nodeID, tlvType, err)
×
4085
                }
×
4086
        }
4087

4088
        return nil
×
4089
}
4090

4091
// srcNodeInfo holds the information about the source node of the graph.
4092
type srcNodeInfo struct {
4093
        // id is the DB level ID of the source node entry in the "nodes" table.
4094
        id int64
4095

4096
        // pub is the public key of the source node.
4097
        pub route.Vertex
4098
}
4099

4100
// getSourceNode returns the DB node ID and pub key of the source node for the
4101
// specified protocol version.
4102
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
4103
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
4104

×
4105
        s.srcNodeMu.Lock()
×
4106
        defer s.srcNodeMu.Unlock()
×
4107

×
4108
        // If we already have the source node ID and pub key cached, then
×
4109
        // return them.
×
4110
        if info, ok := s.srcNodes[version]; ok {
×
4111
                return info.id, info.pub, nil
×
4112
        }
×
4113

4114
        var pubKey route.Vertex
×
4115

×
4116
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
4117
        if err != nil {
×
4118
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
4119
                        err)
×
4120
        }
×
4121

4122
        if len(nodes) == 0 {
×
4123
                return 0, pubKey, ErrSourceNodeNotSet
×
4124
        } else if len(nodes) > 1 {
×
4125
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4126
                        "protocol %s found", version)
×
4127
        }
×
4128

4129
        copy(pubKey[:], nodes[0].PubKey)
×
4130

×
4131
        s.srcNodes[version] = &srcNodeInfo{
×
4132
                id:  nodes[0].NodeID,
×
4133
                pub: pubKey,
×
4134
        }
×
4135

×
4136
        return nodes[0].NodeID, pubKey, nil
×
4137
}
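
The method above is a mutex-guarded lazy cache: check the in-memory map under the lock, fall back to the query on a miss, and memoise the result for subsequent calls. A generic standalone sketch of that pattern, with a hypothetical fetch callback in place of the SQL lookup:

package main

import (
	"fmt"
	"sync"
)

// lazyCache memoises the result of an expensive lookup per key.
type lazyCache struct {
	mu    sync.Mutex
	items map[string]string
	fetch func(key string) (string, error)
}

func (c *lazyCache) get(key string) (string, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Serve from the cache if we already resolved this key.
	if v, ok := c.items[key]; ok {
		return v, nil
	}

	// Otherwise do the expensive lookup once and remember the result.
	v, err := c.fetch(key)
	if err != nil {
		return "", err
	}
	c.items[key] = v

	return v, nil
}

func main() {
	c := &lazyCache{
		items: make(map[string]string),
		fetch: func(key string) (string, error) {
			fmt.Println("fetching", key)
			return "value-for-" + key, nil
		},
	}
	v1, _ := c.get("v1")
	v2, _ := c.get("v1") // second call is served from the cache
	fmt.Println(v1, v2)
}
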
4138

4139
// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV stream.
4140
// This then produces a map from TLV type to value. If the input is not a
4141
// valid TLV stream, then an error is returned.
4142
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4143
        r := bytes.NewReader(data)
×
4144

×
4145
        tlvStream, err := tlv.NewStream()
×
4146
        if err != nil {
×
4147
                return nil, err
×
4148
        }
×
4149

4150
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4151
        // pass it into the P2P decoding variant.
4152
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4153
        if err != nil {
×
4154
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4155
        }
×
4156
        if len(parsedTypes) == 0 {
×
4157
                return nil, nil
×
4158
        }
×
4159

4160
        records := make(map[uint64][]byte)
×
4161
        for k, v := range parsedTypes {
×
4162
                records[uint64(k)] = v
×
4163
        }
×
4164

4165
        return records, nil
×
4166
}
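
Conceptually the extra opaque data is a stream of type/length/value records that gets flattened into a type-to-value map. The toy parser below assumes single-byte types and lengths purely for illustration; the real stream uses BigSize encoding and is decoded with the lnd tlv package as shown above.

package main

import (
	"errors"
	"fmt"
)

// parseSimpleTLV decodes records of the form [type][length][value...] where
// both type and length are a single byte. This is a teaching approximation,
// not the BigSize-based wire encoding.
func parseSimpleTLV(data []byte) (map[uint64][]byte, error) {
	records := make(map[uint64][]byte)
	for len(data) > 0 {
		if len(data) < 2 {
			return nil, errors.New("truncated record header")
		}
		typ, length := uint64(data[0]), int(data[1])
		data = data[2:]

		if len(data) < length {
			return nil, errors.New("truncated record value")
		}
		records[typ] = append([]byte(nil), data[:length]...)
		data = data[length:]
	}

	return records, nil
}

func main() {
	// Type 1 with two bytes of value, then type 3 with one byte.
	recs, err := parseSimpleTLV([]byte{0x01, 0x02, 0xaa, 0xbb, 0x03, 0x01, 0xff})
	fmt.Println(recs, err)
}
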
4167

4168
// insertChannel inserts a new channel record into the database.
4169
func insertChannel(ctx context.Context, db SQLQueries,
4170
        edge *models.ChannelEdgeInfo) error {
×
4171

×
4172
        // Make sure that at least a "shell" entry for each node is present in
×
4173
        // the nodes table.
×
4174
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4175
        if err != nil {
×
4176
                return fmt.Errorf("unable to create shell node: %w", err)
×
4177
        }
×
4178

4179
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4180
        if err != nil {
×
4181
                return fmt.Errorf("unable to create shell node: %w", err)
×
4182
        }
×
4183

4184
        var capacity sql.NullInt64
×
4185
        if edge.Capacity != 0 {
×
4186
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4187
        }
×
4188

4189
        createParams := sqlc.CreateChannelParams{
×
4190
                Version:     int16(lnwire.GossipVersion1),
×
4191
                Scid:        channelIDToBytes(edge.ChannelID),
×
4192
                NodeID1:     node1DBID,
×
4193
                NodeID2:     node2DBID,
×
4194
                Outpoint:    edge.ChannelPoint.String(),
×
4195
                Capacity:    capacity,
×
4196
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4197
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4198
        }
×
4199

×
4200
        if edge.AuthProof != nil {
×
4201
                proof := edge.AuthProof
×
4202

×
4203
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4204
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4205
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4206
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4207
        }
×
4208

4209
        // Insert the new channel record.
4210
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4211
        if err != nil {
×
4212
                return err
×
4213
        }
×
4214

4215
        // Insert any channel features.
4216
        for feature := range edge.Features.Features() {
×
4217
                err = db.InsertChannelFeature(
×
4218
                        ctx, sqlc.InsertChannelFeatureParams{
×
4219
                                ChannelID:  dbChanID,
×
4220
                                FeatureBit: int32(feature),
×
4221
                        },
×
4222
                )
×
4223
                if err != nil {
×
4224
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4225
                                "feature(%v): %w", dbChanID, feature, err)
×
4226
                }
×
4227
        }
4228

4229
        // Finally, insert any extra TLV fields in the channel announcement.
4230
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4231
        if err != nil {
×
4232
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4233
                        err)
×
4234
        }
×
4235

4236
        for tlvType, value := range extra {
×
4237
                err := db.UpsertChannelExtraType(
×
4238
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4239
                                ChannelID: dbChanID,
×
4240
                                Type:      int64(tlvType),
×
4241
                                Value:     value,
×
4242
                        },
×
4243
                )
×
4244
                if err != nil {
×
4245
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4246
                                "extra signed field(%v): %w", edge.ChannelID,
×
4247
                                tlvType, err)
×
4248
                }
×
4249
        }
4250

4251
        return nil
×
4252
}
4253

4254
// maybeCreateShellNode checks if a node entry already exists for the
4255
// given public key. If it does not exist, then a new shell node entry is
4256
// created. The ID of the node is returned. A shell node only has a protocol
4257
// version and public key persisted.
4258
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4259
        pubKey route.Vertex) (int64, error) {
×
4260

×
4261
        dbNode, err := db.GetNodeByPubKey(
×
4262
                ctx, sqlc.GetNodeByPubKeyParams{
×
4263
                        PubKey:  pubKey[:],
×
4264
                        Version: int16(lnwire.GossipVersion1),
×
4265
                },
×
4266
        )
×
4267
        // The node exists. Return the ID.
×
4268
        if err == nil {
×
4269
                return dbNode.ID, nil
×
4270
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4271
                return 0, err
×
4272
        }
×
4273

4274
        // Otherwise, the node does not exist, so we create a shell entry for
4275
        // it.
4276
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4277
                Version: int16(lnwire.GossipVersion1),
×
4278
                PubKey:  pubKey[:],
×
4279
        })
×
4280
        if err != nil {
×
4281
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4282
        }
×
4283

4284
        return id, nil
×
4285
}
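
maybeCreateShellNode is a get-or-create keyed on sql.ErrNoRows: a successful lookup short-circuits, any other error is returned as-is, and only a missing row triggers the insert. A minimal standalone sketch of that control flow, where the lookup and create callbacks are hypothetical stand-ins for the sqlc queries:

package main

import (
	"database/sql"
	"errors"
	"fmt"
)

// getOrCreate returns the existing ID when lookup succeeds, creates a new row
// only when the lookup reports sql.ErrNoRows, and fails on any other error.
func getOrCreate(lookup func() (int64, error),
	create func() (int64, error)) (int64, error) {

	id, err := lookup()
	switch {
	case err == nil:
		return id, nil

	case !errors.Is(err, sql.ErrNoRows):
		// A real failure: do not mask it by creating a row.
		return 0, err
	}

	return create()
}

func main() {
	id, err := getOrCreate(
		func() (int64, error) { return 0, sql.ErrNoRows },
		func() (int64, error) { return 42, nil },
	)
	fmt.Println(id, err)
}
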
4286

4287
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4288
// the database. This includes deleting any existing types and then inserting
4289
// the new types.
4290
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4291
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4292

×
4293
        // Delete all existing extra signed fields for the channel policy.
×
4294
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4295
        if err != nil {
×
4296
                return fmt.Errorf("unable to delete "+
×
4297
                        "existing policy extra signed fields for policy %d: %w",
×
4298
                        chanPolicyID, err)
×
4299
        }
×
4300

4301
        // Insert all new extra signed fields for the channel policy.
4302
        for tlvType, value := range extraFields {
×
4303
                err = db.UpsertChanPolicyExtraType(
×
4304
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4305
                                ChannelPolicyID: chanPolicyID,
×
4306
                                Type:            int64(tlvType),
×
4307
                                Value:           value,
×
4308
                        },
×
4309
                )
×
4310
                if err != nil {
×
4311
                        return fmt.Errorf("unable to insert "+
×
4312
                                "channel_policy(%d) extra signed field(%v): %w",
×
4313
                                chanPolicyID, tlvType, err)
×
4314
                }
×
4315
        }
4316

4317
        return nil
×
4318
}
4319

4320
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4321
// provided dbChan row and also fetches any other required information
4322
// to construct the edge info.
4323
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4324
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4325
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4326

×
4327
        data, err := batchLoadChannelData(
×
4328
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4329
        )
×
4330
        if err != nil {
×
4331
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4332
                        err)
×
4333
        }
×
4334

4335
        return buildEdgeInfoWithBatchData(
×
4336
                cfg.ChainHash, dbChan, node1, node2, data,
×
4337
        )
×
4338
}
4339

4340
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4341
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4342
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4343
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4344

×
4345
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4346
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4347
                        dbChan.Version)
×
4348
        }
×
4349

4350
        // Use pre-loaded features and extra types.
4351
        fv := lnwire.EmptyFeatureVector()
×
4352
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4353
                for _, bit := range features {
×
4354
                        fv.Set(lnwire.FeatureBit(bit))
×
4355
                }
×
4356
        }
4357

4358
        var extras map[uint64][]byte
×
4359
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4360
        if exists {
×
4361
                extras = channelExtras
×
4362
        } else {
×
4363
                extras = make(map[uint64][]byte)
×
4364
        }
×
4365

4366
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4367
        if err != nil {
×
4368
                return nil, err
×
4369
        }
×
4370

4371
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4372
        if err != nil {
×
4373
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4374
                        "fields: %w", err)
×
4375
        }
×
4376
        if recs == nil {
×
4377
                recs = make([]byte, 0)
×
4378
        }
×
4379

4380
        var btcKey1, btcKey2 route.Vertex
×
4381
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4382
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4383

×
4384
        channel := &models.ChannelEdgeInfo{
×
4385
                ChainHash:        chain,
×
4386
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4387
                NodeKey1Bytes:    node1,
×
4388
                NodeKey2Bytes:    node2,
×
4389
                BitcoinKey1Bytes: btcKey1,
×
4390
                BitcoinKey2Bytes: btcKey2,
×
4391
                ChannelPoint:     *op,
×
4392
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4393
                Features:         fv,
×
4394
                ExtraOpaqueData:  recs,
×
4395
        }
×
4396

×
4397
        // We always set all the signatures at the same time, so we can
×
4398
        // safely check if one signature is present to determine if we have the
×
4399
        // rest of the signatures for the auth proof.
×
4400
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4401
                channel.AuthProof = &models.ChannelAuthProof{
×
4402
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4403
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4404
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4405
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4406
                }
×
4407
        }
×
4408

4409
        return channel, nil
×
4410
}
4411

4412
// buildNodeVertices is a helper that converts raw node public keys
4413
// into route.Vertex instances.
4414
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4415
        route.Vertex, error) {
×
4416

×
4417
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4418
        if err != nil {
×
4419
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4420
                        "create vertex from node1 pubkey: %w", err)
×
4421
        }
×
4422

4423
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4424
        if err != nil {
×
4425
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4426
                        "create vertex from node2 pubkey: %w", err)
×
4427
        }
×
4428

4429
        return node1Vertex, node2Vertex, nil
×
4430
}
4431

4432
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4433
// retrieves all the extra info required to build the complete
4434
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4435
// the provided sqlc.GraphChannelPolicy records are nil.
4436
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4437
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4438
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4439
        *models.ChannelEdgePolicy, error) {
×
4440

×
4441
        if dbPol1 == nil && dbPol2 == nil {
×
4442
                return nil, nil, nil
×
4443
        }
×
4444

4445
        var policyIDs = make([]int64, 0, 2)
×
4446
        if dbPol1 != nil {
×
4447
                policyIDs = append(policyIDs, dbPol1.ID)
×
4448
        }
×
4449
        if dbPol2 != nil {
×
4450
                policyIDs = append(policyIDs, dbPol2.ID)
×
4451
        }
×
4452

4453
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4454
        if err != nil {
×
4455
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4456
                        "data: %w", err)
×
4457
        }
×
4458

4459
        pol1, err := buildChanPolicyWithBatchData(
×
4460
                dbPol1, channelID, node2, batchData,
×
4461
        )
×
4462
        if err != nil {
×
4463
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4464
        }
×
4465

4466
        pol2, err := buildChanPolicyWithBatchData(
×
4467
                dbPol2, channelID, node1, batchData,
×
4468
        )
×
4469
        if err != nil {
×
4470
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4471
        }
×
4472

4473
        return pol1, pol2, nil
×
4474
}
4475

4476
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4477
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4478
// then nil is returned for it.
4479
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4480
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4481
        *models.CachedEdgePolicy, error) {
×
4482

×
4483
        var p1, p2 *models.CachedEdgePolicy
×
4484
        if dbPol1 != nil {
×
4485
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4486
                if err != nil {
×
4487
                        return nil, nil, err
×
4488
                }
×
4489

4490
                p1 = models.NewCachedPolicy(policy1)
×
4491
        }
4492
        if dbPol2 != nil {
×
4493
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4494
                if err != nil {
×
4495
                        return nil, nil, err
×
4496
                }
×
4497

4498
                p2 = models.NewCachedPolicy(policy2)
×
4499
        }
4500

4501
        return p1, p2, nil
×
4502
}
4503

4504
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4505
// provided sqlc.GraphChannelPolicy and other required information.
4506
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4507
        extras map[uint64][]byte,
4508
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4509

×
4510
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4511
        if err != nil {
×
4512
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4513
                        "fields: %w", err)
×
4514
        }
×
4515

4516
        var inboundFee fn.Option[lnwire.Fee]
×
4517
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4518
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4519

×
4520
                inboundFee = fn.Some(lnwire.Fee{
×
4521
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4522
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4523
                })
×
4524
        }
×
4525

4526
        return &models.ChannelEdgePolicy{
×
4527
                SigBytes:  dbPolicy.Signature,
×
4528
                ChannelID: channelID,
×
4529
                LastUpdate: time.Unix(
×
4530
                        dbPolicy.LastUpdate.Int64, 0,
×
4531
                ),
×
4532
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4533
                        dbPolicy.MessageFlags,
×
4534
                ),
×
4535
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4536
                        dbPolicy.ChannelFlags,
×
4537
                ),
×
4538
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4539
                MinHTLC: lnwire.MilliSatoshi(
×
4540
                        dbPolicy.MinHtlcMsat,
×
4541
                ),
×
4542
                MaxHTLC: lnwire.MilliSatoshi(
×
4543
                        dbPolicy.MaxHtlcMsat.Int64,
×
4544
                ),
×
4545
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4546
                        dbPolicy.BaseFeeMsat,
×
4547
                ),
×
4548
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4549
                ToNode:                    toNode,
×
4550
                InboundFee:                inboundFee,
×
4551
                ExtraOpaqueData:           recs,
×
4552
        }, nil
×
4553
}
4554

4555
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
4556
// row which is expected to be a sqlc type that contains channel policy
4557
// information. It returns two policies, which may be nil if the policy
4558
// information is not present in the row.
4559
//
4560
//nolint:ll,dupl,funlen
4561
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4562
        *sqlc.GraphChannelPolicy, error) {
×
4563

×
4564
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4565
        switch r := row.(type) {
×
4566
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4567
                if r.Policy1Timelock.Valid {
×
4568
                        policy1 = &sqlc.GraphChannelPolicy{
×
4569
                                Timelock:                r.Policy1Timelock.Int32,
×
4570
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4571
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4572
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4573
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4574
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4575
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4576
                                Disabled:                r.Policy1Disabled,
×
4577
                                MessageFlags:            r.Policy1MessageFlags,
×
4578
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4579
                        }
×
4580
                }
×
4581
                if r.Policy2Timelock.Valid {
×
4582
                        policy2 = &sqlc.GraphChannelPolicy{
×
4583
                                Timelock:                r.Policy2Timelock.Int32,
×
4584
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4585
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4586
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4587
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4588
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4589
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4590
                                Disabled:                r.Policy2Disabled,
×
4591
                                MessageFlags:            r.Policy2MessageFlags,
×
4592
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4593
                        }
×
4594
                }
×
4595

4596
                return policy1, policy2, nil
×
4597

4598
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4599
                if r.Policy1ID.Valid {
×
4600
                        policy1 = &sqlc.GraphChannelPolicy{
×
4601
                                ID:                      r.Policy1ID.Int64,
×
4602
                                Version:                 r.Policy1Version.Int16,
×
4603
                                ChannelID:               r.GraphChannel.ID,
×
4604
                                NodeID:                  r.Policy1NodeID.Int64,
×
4605
                                Timelock:                r.Policy1Timelock.Int32,
×
4606
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4607
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4608
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4609
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4610
                                LastUpdate:              r.Policy1LastUpdate,
×
4611
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4612
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4613
                                Disabled:                r.Policy1Disabled,
×
4614
                                MessageFlags:            r.Policy1MessageFlags,
×
4615
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4616
                                Signature:               r.Policy1Signature,
×
4617
                        }
×
4618
                }
×
4619
                if r.Policy2ID.Valid {
×
4620
                        policy2 = &sqlc.GraphChannelPolicy{
×
4621
                                ID:                      r.Policy2ID.Int64,
×
4622
                                Version:                 r.Policy2Version.Int16,
×
4623
                                ChannelID:               r.GraphChannel.ID,
×
4624
                                NodeID:                  r.Policy2NodeID.Int64,
×
4625
                                Timelock:                r.Policy2Timelock.Int32,
×
4626
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4627
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4628
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4629
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4630
                                LastUpdate:              r.Policy2LastUpdate,
×
4631
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4632
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4633
                                Disabled:                r.Policy2Disabled,
×
4634
                                MessageFlags:            r.Policy2MessageFlags,
×
4635
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4636
                                Signature:               r.Policy2Signature,
×
4637
                        }
×
4638
                }
×
4639

4640
                return policy1, policy2, nil
×
4641

4642
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4643
                if r.Policy1ID.Valid {
×
4644
                        policy1 = &sqlc.GraphChannelPolicy{
×
4645
                                ID:                      r.Policy1ID.Int64,
×
4646
                                Version:                 r.Policy1Version.Int16,
×
4647
                                ChannelID:               r.GraphChannel.ID,
×
4648
                                NodeID:                  r.Policy1NodeID.Int64,
×
4649
                                Timelock:                r.Policy1Timelock.Int32,
×
4650
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4651
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4652
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4653
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4654
                                LastUpdate:              r.Policy1LastUpdate,
×
4655
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4656
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4657
                                Disabled:                r.Policy1Disabled,
×
4658
                                MessageFlags:            r.Policy1MessageFlags,
×
4659
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4660
                                Signature:               r.Policy1Signature,
×
4661
                        }
×
4662
                }
×
4663
                if r.Policy2ID.Valid {
×
4664
                        policy2 = &sqlc.GraphChannelPolicy{
×
4665
                                ID:                      r.Policy2ID.Int64,
×
4666
                                Version:                 r.Policy2Version.Int16,
×
4667
                                ChannelID:               r.GraphChannel.ID,
×
4668
                                NodeID:                  r.Policy2NodeID.Int64,
×
4669
                                Timelock:                r.Policy2Timelock.Int32,
×
4670
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4671
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4672
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4673
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4674
                                LastUpdate:              r.Policy2LastUpdate,
×
4675
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4676
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4677
                                Disabled:                r.Policy2Disabled,
×
4678
                                MessageFlags:            r.Policy2MessageFlags,
×
4679
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4680
                                Signature:               r.Policy2Signature,
×
4681
                        }
×
4682
                }
×
4683

4684
                return policy1, policy2, nil
×
4685

4686
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4687
                if r.Policy1ID.Valid {
×
4688
                        policy1 = &sqlc.GraphChannelPolicy{
×
4689
                                ID:                      r.Policy1ID.Int64,
×
4690
                                Version:                 r.Policy1Version.Int16,
×
4691
                                ChannelID:               r.GraphChannel.ID,
×
4692
                                NodeID:                  r.Policy1NodeID.Int64,
×
4693
                                Timelock:                r.Policy1Timelock.Int32,
×
4694
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4695
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4696
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4697
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4698
                                LastUpdate:              r.Policy1LastUpdate,
×
4699
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4700
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4701
                                Disabled:                r.Policy1Disabled,
×
4702
                                MessageFlags:            r.Policy1MessageFlags,
×
4703
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4704
                                Signature:               r.Policy1Signature,
×
4705
                        }
×
4706
                }
×
4707
                if r.Policy2ID.Valid {
×
4708
                        policy2 = &sqlc.GraphChannelPolicy{
×
4709
                                ID:                      r.Policy2ID.Int64,
×
4710
                                Version:                 r.Policy2Version.Int16,
×
4711
                                ChannelID:               r.GraphChannel.ID,
×
4712
                                NodeID:                  r.Policy2NodeID.Int64,
×
4713
                                Timelock:                r.Policy2Timelock.Int32,
×
4714
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4715
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4716
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4717
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4718
                                LastUpdate:              r.Policy2LastUpdate,
×
4719
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4720
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4721
                                Disabled:                r.Policy2Disabled,
×
4722
                                MessageFlags:            r.Policy2MessageFlags,
×
4723
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4724
                                Signature:               r.Policy2Signature,
×
4725
                        }
×
4726
                }
×
4727

4728
                return policy1, policy2, nil
×
4729

4730
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4731
                if r.Policy1ID.Valid {
×
4732
                        policy1 = &sqlc.GraphChannelPolicy{
×
4733
                                ID:                      r.Policy1ID.Int64,
×
4734
                                Version:                 r.Policy1Version.Int16,
×
4735
                                ChannelID:               r.GraphChannel.ID,
×
4736
                                NodeID:                  r.Policy1NodeID.Int64,
×
4737
                                Timelock:                r.Policy1Timelock.Int32,
×
4738
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4739
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4740
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4741
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4742
                                LastUpdate:              r.Policy1LastUpdate,
×
4743
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4744
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4745
                                Disabled:                r.Policy1Disabled,
×
4746
                                MessageFlags:            r.Policy1MessageFlags,
×
4747
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4748
                                Signature:               r.Policy1Signature,
×
4749
                        }
×
4750
                }
×
4751
                if r.Policy2ID.Valid {
×
4752
                        policy2 = &sqlc.GraphChannelPolicy{
×
4753
                                ID:                      r.Policy2ID.Int64,
×
4754
                                Version:                 r.Policy2Version.Int16,
×
4755
                                ChannelID:               r.GraphChannel.ID,
×
4756
                                NodeID:                  r.Policy2NodeID.Int64,
×
4757
                                Timelock:                r.Policy2Timelock.Int32,
×
4758
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4759
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4760
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4761
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4762
                                LastUpdate:              r.Policy2LastUpdate,
×
4763
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4764
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4765
                                Disabled:                r.Policy2Disabled,
×
4766
                                MessageFlags:            r.Policy2MessageFlags,
×
4767
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4768
                                Signature:               r.Policy2Signature,
×
4769
                        }
×
4770
                }
×
4771

4772
                return policy1, policy2, nil
×
4773

4774
        case sqlc.ListChannelsForNodeIDsRow:
×
4775
                if r.Policy1ID.Valid {
×
4776
                        policy1 = &sqlc.GraphChannelPolicy{
×
4777
                                ID:                      r.Policy1ID.Int64,
×
4778
                                Version:                 r.Policy1Version.Int16,
×
4779
                                ChannelID:               r.GraphChannel.ID,
×
4780
                                NodeID:                  r.Policy1NodeID.Int64,
×
4781
                                Timelock:                r.Policy1Timelock.Int32,
×
4782
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4783
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4784
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4785
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4786
                                LastUpdate:              r.Policy1LastUpdate,
×
4787
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4788
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4789
                                Disabled:                r.Policy1Disabled,
×
4790
                                MessageFlags:            r.Policy1MessageFlags,
×
4791
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4792
                                Signature:               r.Policy1Signature,
×
4793
                        }
×
4794
                }
×
4795
                if r.Policy2ID.Valid {
×
4796
                        policy2 = &sqlc.GraphChannelPolicy{
×
4797
                                ID:                      r.Policy2ID.Int64,
×
4798
                                Version:                 r.Policy2Version.Int16,
×
4799
                                ChannelID:               r.GraphChannel.ID,
×
4800
                                NodeID:                  r.Policy2NodeID.Int64,
×
4801
                                Timelock:                r.Policy2Timelock.Int32,
×
4802
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4803
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4804
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4805
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4806
                                LastUpdate:              r.Policy2LastUpdate,
×
4807
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4808
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4809
                                Disabled:                r.Policy2Disabled,
×
4810
                                MessageFlags:            r.Policy2MessageFlags,
×
4811
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4812
                                Signature:               r.Policy2Signature,
×
4813
                        }
×
4814
                }
×
4815

4816
                return policy1, policy2, nil
×
4817

4818
        case sqlc.ListChannelsByNodeIDRow:
×
4819
                if r.Policy1ID.Valid {
×
4820
                        policy1 = &sqlc.GraphChannelPolicy{
×
4821
                                ID:                      r.Policy1ID.Int64,
×
4822
                                Version:                 r.Policy1Version.Int16,
×
4823
                                ChannelID:               r.GraphChannel.ID,
×
4824
                                NodeID:                  r.Policy1NodeID.Int64,
×
4825
                                Timelock:                r.Policy1Timelock.Int32,
×
4826
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4827
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4828
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4829
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4830
                                LastUpdate:              r.Policy1LastUpdate,
×
4831
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4832
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4833
                                Disabled:                r.Policy1Disabled,
×
4834
                                MessageFlags:            r.Policy1MessageFlags,
×
4835
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4836
                                Signature:               r.Policy1Signature,
×
4837
                        }
×
4838
                }
×
4839
                if r.Policy2ID.Valid {
×
4840
                        policy2 = &sqlc.GraphChannelPolicy{
×
4841
                                ID:                      r.Policy2ID.Int64,
×
4842
                                Version:                 r.Policy2Version.Int16,
×
4843
                                ChannelID:               r.GraphChannel.ID,
×
4844
                                NodeID:                  r.Policy2NodeID.Int64,
×
4845
                                Timelock:                r.Policy2Timelock.Int32,
×
4846
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4847
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4848
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4849
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4850
                                LastUpdate:              r.Policy2LastUpdate,
×
4851
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4852
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4853
                                Disabled:                r.Policy2Disabled,
×
4854
                                MessageFlags:            r.Policy2MessageFlags,
×
4855
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4856
                                Signature:               r.Policy2Signature,
×
4857
                        }
×
4858
                }
×
4859

4860
                return policy1, policy2, nil
×
4861

4862
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4863
                if r.Policy1ID.Valid {
×
4864
                        policy1 = &sqlc.GraphChannelPolicy{
×
4865
                                ID:                      r.Policy1ID.Int64,
×
4866
                                Version:                 r.Policy1Version.Int16,
×
4867
                                ChannelID:               r.GraphChannel.ID,
×
4868
                                NodeID:                  r.Policy1NodeID.Int64,
×
4869
                                Timelock:                r.Policy1Timelock.Int32,
×
4870
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4871
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4872
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4873
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4874
                                LastUpdate:              r.Policy1LastUpdate,
×
4875
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4876
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4877
                                Disabled:                r.Policy1Disabled,
×
4878
                                MessageFlags:            r.Policy1MessageFlags,
×
4879
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4880
                                Signature:               r.Policy1Signature,
×
4881
                        }
×
4882
                }
×
4883
                if r.Policy2ID.Valid {
×
4884
                        policy2 = &sqlc.GraphChannelPolicy{
×
4885
                                ID:                      r.Policy2ID.Int64,
×
4886
                                Version:                 r.Policy2Version.Int16,
×
4887
                                ChannelID:               r.GraphChannel.ID,
×
4888
                                NodeID:                  r.Policy2NodeID.Int64,
×
4889
                                Timelock:                r.Policy2Timelock.Int32,
×
4890
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4891
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4892
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4893
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4894
                                LastUpdate:              r.Policy2LastUpdate,
×
4895
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4896
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4897
                                Disabled:                r.Policy2Disabled,
×
4898
                                MessageFlags:            r.Policy2MessageFlags,
×
4899
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4900
                                Signature:               r.Policy2Signature,
×
4901
                        }
×
4902
                }
×
4903

4904
                return policy1, policy2, nil
×
4905

4906
        case sqlc.GetChannelsByIDsRow:
×
4907
                if r.Policy1ID.Valid {
×
4908
                        policy1 = &sqlc.GraphChannelPolicy{
×
4909
                                ID:                      r.Policy1ID.Int64,
×
4910
                                Version:                 r.Policy1Version.Int16,
×
4911
                                ChannelID:               r.GraphChannel.ID,
×
4912
                                NodeID:                  r.Policy1NodeID.Int64,
×
4913
                                Timelock:                r.Policy1Timelock.Int32,
×
4914
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4915
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4916
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4917
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4918
                                LastUpdate:              r.Policy1LastUpdate,
×
4919
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4920
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4921
                                Disabled:                r.Policy1Disabled,
×
4922
                                MessageFlags:            r.Policy1MessageFlags,
×
4923
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4924
                                Signature:               r.Policy1Signature,
×
4925
                        }
×
4926
                }
×
4927
                if r.Policy2ID.Valid {
×
4928
                        policy2 = &sqlc.GraphChannelPolicy{
×
4929
                                ID:                      r.Policy2ID.Int64,
×
4930
                                Version:                 r.Policy2Version.Int16,
×
4931
                                ChannelID:               r.GraphChannel.ID,
×
4932
                                NodeID:                  r.Policy2NodeID.Int64,
×
4933
                                Timelock:                r.Policy2Timelock.Int32,
×
4934
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4935
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4936
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4937
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4938
                                LastUpdate:              r.Policy2LastUpdate,
×
4939
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4940
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4941
                                Disabled:                r.Policy2Disabled,
×
4942
                                MessageFlags:            r.Policy2MessageFlags,
×
4943
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4944
                                Signature:               r.Policy2Signature,
×
4945
                        }
×
4946
                }
×
4947

4948
                return policy1, policy2, nil
×
4949

4950
        default:
×
4951
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4952
                        "extractChannelPolicies: %T", r)
×
4953
        }
4954
}
4955

4956
// channelIDToBytes converts a channel ID (SCID) to a byte array
4957
// representation.
4958
func channelIDToBytes(channelID uint64) []byte {
×
4959
        var chanIDB [8]byte
×
4960
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4961

×
4962
        return chanIDB[:]
×
4963
}
×
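// Editor's note: the following sketch is illustrative only and is not part of
// sql_store.go. It assumes byteOrder (defined elsewhere in this file) is
// big-endian, the conventional serialization for short channel IDs; under
// that assumption the most significant byte of the SCID comes first.
func sketchChannelIDToBytes() {
        b := channelIDToBytes(0x0102030405060708)

        // With a big-endian byteOrder this prints "0102030405060708".
        fmt.Printf("%x\n", b)
}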
4964

4965
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4966
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4967
        if len(addresses) == 0 {
×
4968
                return nil, nil
×
4969
        }
×
4970

4971
        result := make([]net.Addr, 0, len(addresses))
×
4972
        for _, addr := range addresses {
×
4973
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4974
                if err != nil {
×
4975
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4976
                                "of type %d: %w", addr.address, addr.addrType,
×
4977
                                err)
×
4978
                }
×
4979
                if netAddr != nil {
×
4980
                        result = append(result, netAddr)
×
4981
                }
×
4982
        }
4983

4984
        // If we have no valid addresses, return nil instead of empty slice.
4985
        if len(result) == 0 {
×
4986
                return nil, nil
×
4987
        }
×
4988

4989
        return result, nil
×
4990
}
4991

4992
// parseAddress parses the given address string based on the address type
4993
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4994
// DNS, and opaque addresses.
4995
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4996
        switch addrType {
×
4997
        case addressTypeIPv4:
×
4998
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4999
                if err != nil {
×
5000
                        return nil, err
×
5001
                }
×
5002

5003
                tcp.IP = tcp.IP.To4()
×
5004

×
5005
                return tcp, nil
×
5006

5007
        case addressTypeIPv6:
×
5008
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
5009
                if err != nil {
×
5010
                        return nil, err
×
5011
                }
×
5012

5013
                return tcp, nil
×
5014

5015
        case addressTypeTorV3, addressTypeTorV2:
×
5016
                service, portStr, err := net.SplitHostPort(address)
×
5017
                if err != nil {
×
5018
                        return nil, fmt.Errorf("unable to split tor "+
×
5019
                                "address: %v", address)
×
5020
                }
×
5021

5022
                port, err := strconv.Atoi(portStr)
×
5023
                if err != nil {
×
5024
                        return nil, err
×
5025
                }
×
5026

5027
                return &tor.OnionAddr{
×
5028
                        OnionService: service,
×
5029
                        Port:         port,
×
5030
                }, nil
×
5031

5032
        case addressTypeDNS:
×
5033
                hostname, portStr, err := net.SplitHostPort(address)
×
5034
                if err != nil {
×
5035
                        return nil, fmt.Errorf("unable to split DNS "+
×
5036
                                "address: %v", address)
×
5037
                }
×
5038

5039
                port, err := strconv.Atoi(portStr)
×
5040
                if err != nil {
×
5041
                        return nil, err
×
5042
                }
×
5043

5044
                return &lnwire.DNSAddress{
×
5045
                        Hostname: hostname,
×
5046
                        Port:     uint16(port),
×
5047
                }, nil
×
5048

5049
        case addressTypeOpaque:
×
5050
                opaque, err := hex.DecodeString(address)
×
5051
                if err != nil {
×
5052
                        return nil, fmt.Errorf("unable to decode opaque "+
×
5053
                                "address: %v", address)
×
5054
                }
×
5055

5056
                return &lnwire.OpaqueAddrs{
×
5057
                        Payload: opaque,
×
5058
                }, nil
×
5059

5060
        default:
×
5061
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
5062
        }
5063
}
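// Editor's note: the following sketch is illustrative only and is not part of
// sql_store.go. It shows how the address types handled above map onto
// concrete net.Addr implementations; the onion host below is a placeholder,
// and the dbAddressType constants are the ones defined elsewhere in this
// file.
func sketchParseAddress() error {
        ip, err := parseAddress(addressTypeIPv4, "203.0.113.7:9735")
        if err != nil {
                return err
        }
        _ = ip.(*net.TCPAddr) // IPv4 and IPv6 resolve to *net.TCPAddr.

        onion, err := parseAddress(addressTypeTorV3, "examplehost.onion:9735")
        if err != nil {
                return err
        }
        _ = onion.(*tor.OnionAddr) // Tor v2/v3 become *tor.OnionAddr.

        return nil
}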
5064

5065
// batchNodeData holds all the related data for a batch of nodes.
5066
type batchNodeData struct {
5067
        // features is a map from a DB node ID to the feature bits for that
5068
        // node.
5069
        features map[int64][]int
5070

5071
        // addresses is a map from a DB node ID to the node's addresses.
5072
        addresses map[int64][]nodeAddress
5073

5074
        // extraFields is a map from a DB node ID to the extra signed fields
5075
        // for that node.
5076
        extraFields map[int64]map[uint64][]byte
5077
}
5078

5079
// nodeAddress holds the address type, position and address string for a
5080
// node. This is used to batch the fetching of node addresses.
5081
type nodeAddress struct {
5082
        addrType dbAddressType
5083
        position int32
5084
        address  string
5085
}
5086

5087
// batchLoadNodeData loads all related data for a batch of node IDs using the
5088
// provided SQLQueries interface. It returns a batchNodeData instance containing
5089
// the node features, addresses and extra signed fields.
5090
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
5091
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
5092

×
5093
        // Batch load the node features.
×
5094
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
5095
        if err != nil {
×
5096
                return nil, fmt.Errorf("unable to batch load node "+
×
5097
                        "features: %w", err)
×
5098
        }
×
5099

5100
        // Batch load the node addresses.
5101
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
5102
        if err != nil {
×
5103
                return nil, fmt.Errorf("unable to batch load node "+
×
5104
                        "addresses: %w", err)
×
5105
        }
×
5106

5107
        // Batch load the node extra signed fields.
5108
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
5109
        if err != nil {
×
5110
                return nil, fmt.Errorf("unable to batch load node extra "+
×
5111
                        "signed fields: %w", err)
×
5112
        }
×
5113

5114
        return &batchNodeData{
×
5115
                features:    features,
×
5116
                addresses:   addrs,
×
5117
                extraFields: extraTypes,
×
5118
        }, nil
×
5119
}
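// Editor's note: the following sketch is illustrative only and is not part of
// sql_store.go. It shows the shape of the data returned above: every map in
// batchNodeData is keyed by the DB-level node ID, so the three batched
// queries above serve all subsequent per-node lookups.
func sketchUseBatchNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodeIDs []int64) error {

        data, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
        if err != nil {
                return err
        }

        for _, id := range nodeIDs {
                _ = data.features[id]    // Feature bits for the node.
                _ = data.addresses[id]   // nodeAddress entries for the node.
                _ = data.extraFields[id] // Extra signed TLV fields.
        }

        return nil
}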
5120

5121
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5122
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5123
func batchLoadNodeFeaturesHelper(ctx context.Context,
5124
        cfg *sqldb.QueryConfig, db SQLQueries,
5125
        nodeIDs []int64) (map[int64][]int, error) {
×
5126

×
5127
        features := make(map[int64][]int)
×
5128

×
5129
        return features, sqldb.ExecuteBatchQuery(
×
5130
                ctx, cfg, nodeIDs,
×
5131
                func(id int64) int64 {
×
5132
                        return id
×
5133
                },
×
5134
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5135
                        error) {
×
5136

×
5137
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5138
                },
×
5139
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5140
                        features[feature.NodeID] = append(
×
5141
                                features[feature.NodeID],
×
5142
                                int(feature.FeatureBit),
×
5143
                        )
×
5144

×
5145
                        return nil
×
5146
                },
×
5147
        )
5148
}
5149

5150
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5151
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5152
// node ID to a slice of nodeAddress structs.
5153
func batchLoadNodeAddressesHelper(ctx context.Context,
5154
        cfg *sqldb.QueryConfig, db SQLQueries,
5155
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5156

×
5157
        addrs := make(map[int64][]nodeAddress)
×
5158

×
5159
        return addrs, sqldb.ExecuteBatchQuery(
×
5160
                ctx, cfg, nodeIDs,
×
5161
                func(id int64) int64 {
×
5162
                        return id
×
5163
                },
×
5164
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5165
                        error) {
×
5166

×
5167
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5168
                },
×
5169
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5170
                        addrs[addr.NodeID] = append(
×
5171
                                addrs[addr.NodeID], nodeAddress{
×
5172
                                        addrType: dbAddressType(addr.Type),
×
5173
                                        position: addr.Position,
×
5174
                                        address:  addr.Address,
×
5175
                                },
×
5176
                        )
×
5177

×
5178
                        return nil
×
5179
                },
×
5180
        )
5181
}
5182

5183
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5184
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5185
// query.
5186
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5187
        cfg *sqldb.QueryConfig, db SQLQueries,
5188
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5189

×
5190
        extraFields := make(map[int64]map[uint64][]byte)
×
5191

×
5192
        callback := func(ctx context.Context,
×
5193
                field sqlc.GraphNodeExtraType) error {
×
5194

×
5195
                if extraFields[field.NodeID] == nil {
×
5196
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5197
                }
×
5198
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5199

×
5200
                return nil
×
5201
        }
5202

5203
        return extraFields, sqldb.ExecuteBatchQuery(
×
5204
                ctx, cfg, nodeIDs,
×
5205
                func(id int64) int64 {
×
5206
                        return id
×
5207
                },
×
5208
                func(ctx context.Context, ids []int64) (
5209
                        []sqlc.GraphNodeExtraType, error) {
×
5210

×
5211
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5212
                },
×
5213
                callback,
5214
        )
5215
}
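// Editor's note: the following sketch is illustrative only and is not part of
// sql_store.go. The three node helpers above all share the same
// sqldb.ExecuteBatchQuery shape: a key function applied to each input ID
// (identity here), a query run once per chunk of IDs, and a per-row callback
// that folds results into a map. The widget names below (GetWidgetsBatch,
// sqlc.Widget and its fields) are hypothetical stand-ins for a real sqlc
// query and row type, so this is a pattern template rather than compilable
// code.
func batchLoadWidgetsSketch(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, ids []int64) (map[int64][]string, error) {

        widgets := make(map[int64][]string)

        return widgets, sqldb.ExecuteBatchQuery(
                ctx, cfg, ids,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.Widget, error) {
                        return db.GetWidgetsBatch(ctx, ids)
                },
                func(ctx context.Context, w sqlc.Widget) error {
                        widgets[w.OwnerID] = append(widgets[w.OwnerID], w.Name)

                        return nil
                },
        )
}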
5216

5217
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5218
// from the provided sqlc.GraphChannelPolicy records and the
5219
// provided batchChannelData.
5220
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5221
        channelID uint64, node1, node2 route.Vertex,
5222
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5223
        *models.ChannelEdgePolicy, error) {
×
5224

×
5225
        pol1, err := buildChanPolicyWithBatchData(
×
5226
                dbPol1, channelID, node2, batchData,
×
5227
        )
×
5228
        if err != nil {
×
5229
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5230
        }
×
5231

5232
        pol2, err := buildChanPolicyWithBatchData(
×
5233
                dbPol2, channelID, node1, batchData,
×
5234
        )
×
5235
        if err != nil {
×
5236
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5237
        }
×
5238

5239
        return pol1, pol2, nil
×
5240
}
5241

5242
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5243
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5244
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5245
        channelID uint64, toNode route.Vertex,
5246
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
5247

×
5248
        if dbPol == nil {
×
5249
                return nil, nil
×
5250
        }
×
5251

5252
        var dbPol1Extras map[uint64][]byte
×
5253
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5254
                dbPol1Extras = extras
×
5255
        } else {
×
5256
                dbPol1Extras = make(map[uint64][]byte)
×
5257
        }
×
5258

5259
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5260
}
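// Editor's note: the following sketch is illustrative only and is not part of
// sql_store.go. It spells out the nil contract of the helper above: a missing
// policy row stays nil instead of turning into a zero-valued policy, which is
// how callers distinguish "no policy announced" from a policy with default
// values.
func sketchNilPolicy(channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) error {

        pol, err := buildChanPolicyWithBatchData(
                nil, channelID, toNode, batchData,
        )
        if err != nil {
                return err
        }
        if pol != nil {
                return fmt.Errorf("expected nil policy for nil input")
        }

        return nil
}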
5261

5262
// batchChannelData holds all the related data for a batch of channels.
5263
type batchChannelData struct {
5264
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
5265
        chanfeatures map[int64][]int
5266

5267
        // chanExtras is a map from DB channel ID to a map of TLV type to
5268
        // extra signed field bytes.
5269
        chanExtraTypes map[int64]map[uint64][]byte
5270

5271
        // policyExtras is a map from DB channel policy ID to a map of TLV type
5272
        // to extra signed field bytes.
5273
        policyExtras map[int64]map[uint64][]byte
5274
}
5275

5276
// batchLoadChannelData loads all related data for batches of channels and
5277
// policies.
5278
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5279
        db SQLQueries, channelIDs []int64,
5280
        policyIDs []int64) (*batchChannelData, error) {
×
5281

×
5282
        batchData := &batchChannelData{
×
5283
                chanfeatures:   make(map[int64][]int),
×
5284
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5285
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5286
        }
×
5287

×
5288
        // Batch load channel features and extras
×
5289
        var err error
×
5290
        if len(channelIDs) > 0 {
×
5291
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5292
                        ctx, cfg, db, channelIDs,
×
5293
                )
×
5294
                if err != nil {
×
5295
                        return nil, fmt.Errorf("unable to batch load "+
×
5296
                                "channel features: %w", err)
×
5297
                }
×
5298

5299
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5300
                        ctx, cfg, db, channelIDs,
×
5301
                )
×
5302
                if err != nil {
×
5303
                        return nil, fmt.Errorf("unable to batch load "+
×
5304
                                "channel extras: %w", err)
×
5305
                }
×
5306
        }
5307

5308
        if len(policyIDs) > 0 {
×
5309
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5310
                        ctx, cfg, db, policyIDs,
×
5311
                )
×
5312
                if err != nil {
×
5313
                        return nil, fmt.Errorf("unable to batch load "+
×
5314
                                "policy extras: %w", err)
×
5315
                }
×
5316
                batchData.policyExtras = policyExtras
×
5317
        }
5318

5319
        return batchData, nil
×
5320
}
5321

5322
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5323
// channel IDs using ExecuteBatchQuery wrapper around the
5324
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5325
// slice of feature bits.
5326
func batchLoadChannelFeaturesHelper(ctx context.Context,
5327
        cfg *sqldb.QueryConfig, db SQLQueries,
5328
        channelIDs []int64) (map[int64][]int, error) {
×
5329

×
5330
        features := make(map[int64][]int)
×
5331

×
5332
        return features, sqldb.ExecuteBatchQuery(
×
5333
                ctx, cfg, channelIDs,
×
5334
                func(id int64) int64 {
×
5335
                        return id
×
5336
                },
×
5337
                func(ctx context.Context,
5338
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5339

×
5340
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5341
                },
×
5342
                func(ctx context.Context,
5343
                        feature sqlc.GraphChannelFeature) error {
×
5344

×
5345
                        features[feature.ChannelID] = append(
×
5346
                                features[feature.ChannelID],
×
5347
                                int(feature.FeatureBit),
×
5348
                        )
×
5349

×
5350
                        return nil
×
5351
                },
×
5352
        )
5353
}
5354

5355
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5356
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5357
// query. It returns a map from DB channel ID to a map of TLV type to extra
5358
// signed field bytes.
5359
func batchLoadChannelExtrasHelper(ctx context.Context,
5360
        cfg *sqldb.QueryConfig, db SQLQueries,
5361
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5362

×
5363
        extras := make(map[int64]map[uint64][]byte)
×
5364

×
5365
        cb := func(ctx context.Context,
×
5366
                extra sqlc.GraphChannelExtraType) error {
×
5367

×
5368
                if extras[extra.ChannelID] == nil {
×
5369
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5370
                }
×
5371
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5372

×
5373
                return nil
×
5374
        }
5375

5376
        return extras, sqldb.ExecuteBatchQuery(
×
5377
                ctx, cfg, channelIDs,
×
5378
                func(id int64) int64 {
×
5379
                        return id
×
5380
                },
×
5381
                func(ctx context.Context,
5382
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5383

×
5384
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5385
                }, cb,
×
5386
        )
5387
}
5388

5389
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5390
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5391
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5392
// a map of TLV type to extra signed field bytes.
5393
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5394
        cfg *sqldb.QueryConfig, db SQLQueries,
5395
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5396

×
5397
        extras := make(map[int64]map[uint64][]byte)
×
5398

×
5399
        return extras, sqldb.ExecuteBatchQuery(
×
5400
                ctx, cfg, policyIDs,
×
5401
                func(id int64) int64 {
×
5402
                        return id
×
5403
                },
×
5404
                func(ctx context.Context, ids []int64) (
5405
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5406

×
5407
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5408
                },
×
5409
                func(ctx context.Context,
5410
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5411

×
5412
                        if extras[row.PolicyID] == nil {
×
5413
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5414
                        }
×
5415
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5416

×
5417
                        return nil
×
5418
                },
5419
        )
5420
}
5421

5422
// forEachNodePaginated executes a paginated query to process each node in the
5423
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5424
// and applies the provided processNode function to each node.
5425
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5426
        db SQLQueries, protocol lnwire.GossipVersion,
5427
        processNode func(context.Context, int64,
5428
                *models.Node) error) error {
×
5429

×
5430
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5431
                limit int32) ([]sqlc.GraphNode, error) {
×
5432

×
5433
                return db.ListNodesPaginated(
×
5434
                        ctx, sqlc.ListNodesPaginatedParams{
×
5435
                                Version: int16(protocol),
×
5436
                                ID:      lastID,
×
5437
                                Limit:   limit,
×
5438
                        },
×
5439
                )
×
5440
        }
×
5441

5442
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5443
                return node.ID
×
5444
        }
×
5445

5446
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5447
                return node.ID, nil
×
5448
        }
×
5449

5450
        batchQueryFunc := func(ctx context.Context,
×
5451
                nodeIDs []int64) (*batchNodeData, error) {
×
5452

×
5453
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5454
        }
×
5455

5456
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5457
                batchData *batchNodeData) error {
×
5458

×
5459
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5460
                if err != nil {
×
5461
                        return fmt.Errorf("unable to build "+
×
5462
                                "node(id=%d): %w", dbNode.ID, err)
×
5463
                }
×
5464

5465
                return processNode(ctx, dbNode.ID, node)
×
5466
        }
5467

5468
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5469
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5470
                collectFunc, batchQueryFunc, processItem,
×
5471
        )
×
5472
}
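// Editor's note: the following sketch is illustrative only and is not part of
// sql_store.go. It shows a minimal caller: the closure only counts nodes,
// while forEachNodePaginated takes care of paging through the node table and
// batch-loading each page's features, addresses and extra TLV fields.
func sketchCountNodes(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries) (int, error) {

        var count int
        err := forEachNodePaginated(
                ctx, cfg, db, lnwire.GossipVersion1,
                func(_ context.Context, _ int64, _ *models.Node) error {
                        count++

                        return nil
                },
        )

        return count, err
}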
5473

5474
// forEachChannelWithPolicies executes a paginated query to process each channel
5475
// with policies in the graph.
5476
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5477
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5478
                *models.ChannelEdgePolicy,
5479
                *models.ChannelEdgePolicy) error) error {
×
5480

×
5481
        type channelBatchIDs struct {
×
5482
                channelID int64
×
5483
                policyIDs []int64
×
5484
        }
×
5485

×
5486
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5487
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5488
                error) {
×
5489

×
5490
                return db.ListChannelsWithPoliciesPaginated(
×
5491
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5492
                                Version: int16(lnwire.GossipVersion1),
×
5493
                                ID:      lastID,
×
5494
                                Limit:   limit,
×
5495
                        },
×
5496
                )
×
5497
        }
×
5498

5499
        extractPageCursor := func(
×
5500
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5501

×
5502
                return row.GraphChannel.ID
×
5503
        }
×
5504

5505
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5506
                channelBatchIDs, error) {
×
5507

×
5508
                ids := channelBatchIDs{
×
5509
                        channelID: row.GraphChannel.ID,
×
5510
                }
×
5511

×
5512
                // Extract policy IDs from the row.
×
5513
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5514
                if err != nil {
×
5515
                        return ids, err
×
5516
                }
×
5517

5518
                if dbPol1 != nil {
×
5519
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5520
                }
×
5521
                if dbPol2 != nil {
×
5522
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5523
                }
×
5524

5525
                return ids, nil
×
5526
        }
5527

5528
        batchDataFunc := func(ctx context.Context,
×
5529
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5530

×
5531
                // Separate channel IDs from policy IDs.
×
5532
                var (
×
5533
                        channelIDs = make([]int64, len(allIDs))
×
5534
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5535
                )
×
5536

×
5537
                for i, ids := range allIDs {
×
5538
                        channelIDs[i] = ids.channelID
×
5539
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5540
                }
×
5541

5542
                return batchLoadChannelData(
×
5543
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5544
                )
×
5545
        }
5546

5547
        processItem := func(ctx context.Context,
×
5548
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5549
                batchData *batchChannelData) error {
×
5550

×
5551
                node1, node2, err := buildNodeVertices(
×
5552
                        row.Node1Pubkey, row.Node2Pubkey,
×
5553
                )
×
5554
                if err != nil {
×
5555
                        return err
×
5556
                }
×
5557

5558
                edge, err := buildEdgeInfoWithBatchData(
×
5559
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5560
                        batchData,
×
5561
                )
×
5562
                if err != nil {
×
5563
                        return fmt.Errorf("unable to build channel info: %w",
×
5564
                                err)
×
5565
                }
×
5566

5567
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5568
                if err != nil {
×
5569
                        return err
×
5570
                }
×
5571

5572
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5573
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5574
                )
×
5575
                if err != nil {
×
5576
                        return err
×
5577
                }
×
5578

5579
                return processChannel(edge, p1, p2)
×
5580
        }
5581

5582
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5583
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5584
                collectFunc, batchDataFunc, processItem,
×
5585
        )
×
5586
}
5587

5588
// buildDirectedChannel builds a DirectedChannel instance from the provided
5589
// data.
5590
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5591
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5592
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5593
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5594

×
5595
        node1, node2, err := buildNodeVertices(
×
5596
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5597
        )
×
5598
        if err != nil {
×
5599
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5600
        }
×
5601

5602
        edge, err := buildEdgeInfoWithBatchData(
×
5603
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5604
        )
×
5605
        if err != nil {
×
5606
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5607
        }
×
5608

5609
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5610
        if err != nil {
×
5611
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5612
                        err)
×
5613
        }
×
5614

5615
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5616
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5617
                channelBatchData,
×
5618
        )
×
5619
        if err != nil {
×
5620
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5621
                        err)
×
5622
        }
×
5623

5624
        // Determine outgoing and incoming policy for this specific node.
5625
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5626
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5627
        outPolicy, inPolicy := p1, p2
×
5628
        if (p1 != nil && p1ToNode == nodeID) ||
×
5629
                (p2 != nil && p2ToNode != nodeID) {
×
5630

×
5631
                outPolicy, inPolicy = p2, p1
×
5632
        }
×
5633

5634
        // Build cached policy.
5635
        var cachedInPolicy *models.CachedEdgePolicy
×
5636
        if inPolicy != nil {
×
5637
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5638
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5639
                cachedInPolicy.ToNodeFeatures = features
×
5640
        }
×
5641

5642
        // Extract inbound fee.
5643
        var inboundFee lnwire.Fee
×
5644
        if outPolicy != nil {
×
5645
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5646
                        inboundFee = fee
×
5647
                })
×
5648
        }
5649

5650
        // Build directed channel.
5651
        directedChannel := &DirectedChannel{
×
5652
                ChannelID:    edge.ChannelID,
×
5653
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5654
                OtherNode:    edge.NodeKey2Bytes,
×
5655
                Capacity:     edge.Capacity,
×
5656
                OutPolicySet: outPolicy != nil,
×
5657
                InPolicy:     cachedInPolicy,
×
5658
                InboundFee:   inboundFee,
×
5659
        }
×
5660

×
5661
        if nodePub == edge.NodeKey2Bytes {
×
5662
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5663
        }
×
5664

5665
        return directedChannel, nil
×
5666
}
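// Editor's note: the following remark is illustrative only and is not part of
// sql_store.go. Reading the p1ToNode/p2ToNode assignments and the swap above,
// the policies are oriented relative to the node the channel is viewed from:
//
//      viewing node == node1:  outPolicy = policy1, inPolicy = policy2
//      viewing node == node2:  outPolicy = policy2, inPolicy = policy1
//
// In other words, InPolicy is always populated from the edge whose
// destination node (the p1ToNode/p2ToNode value) is the viewing node itself.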
5667

5668
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5669
// provided rows. It uses batch loading for channels, policies, and nodes.
5670
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5671
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5672

×
5673
        var (
×
5674
                channelIDs = make([]int64, len(rows))
×
5675
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5676
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5677

×
5678
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5679
                nodeIDSet = make(map[int64]bool)
×
5680

×
5681
                // edges will hold the final channel edges built from the rows.
×
5682
                edges = make([]ChannelEdge, 0, len(rows))
×
5683
        )
×
5684

×
5685
        // Collect all IDs needed for batch loading.
×
5686
        for i, row := range rows {
×
5687
                channelIDs[i] = row.Channel().ID
×
5688

×
5689
                // Collect policy IDs
×
5690
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5691
                if err != nil {
×
5692
                        return nil, fmt.Errorf("unable to extract channel "+
×
5693
                                "policies: %w", err)
×
5694
                }
×
5695
                if dbPol1 != nil {
×
5696
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5697
                }
×
5698
                if dbPol2 != nil {
×
5699
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5700
                }
×
5701

5702
                var (
×
5703
                        node1ID = row.Node1().ID
×
5704
                        node2ID = row.Node2().ID
×
5705
                )
×
5706

×
5707
                // Collect unique node IDs.
×
5708
                if !nodeIDSet[node1ID] {
×
5709
                        nodeIDs = append(nodeIDs, node1ID)
×
5710
                        nodeIDSet[node1ID] = true
×
5711
                }
×
5712

5713
                if !nodeIDSet[node2ID] {
×
5714
                        nodeIDs = append(nodeIDs, node2ID)
×
5715
                        nodeIDSet[node2ID] = true
×
5716
                }
×
5717
        }
5718

5719
        // Batch the data for all the channels and policies.
5720
        channelBatchData, err := batchLoadChannelData(
×
5721
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5722
        )
×
5723
        if err != nil {
×
5724
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5725
                        "policy data: %w", err)
×
5726
        }
×
5727

5728
        // Batch the data for all the nodes.
5729
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5730
        if err != nil {
×
5731
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5732
                        err)
×
5733
        }
×
5734

5735
        // Build all channel edges using batch data.
5736
        for _, row := range rows {
×
5737
                // Build nodes using batch data.
×
5738
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5739
                if err != nil {
×
5740
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5741
                }
×
5742

5743
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5744
                if err != nil {
×
5745
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5746
                }
×
5747

5748
                // Build channel info using batch data.
5749
                channel, err := buildEdgeInfoWithBatchData(
×
5750
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5751
                        node2.PubKeyBytes, channelBatchData,
×
5752
                )
×
5753
                if err != nil {
×
5754
                        return nil, fmt.Errorf("unable to build channel "+
×
5755
                                "info: %w", err)
×
5756
                }
×
5757

5758
                // Extract and build policies using batch data.
5759
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5760
                if err != nil {
×
5761
                        return nil, fmt.Errorf("unable to extract channel "+
×
5762
                                "policies: %w", err)
×
5763
                }
×
5764

5765
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5766
                        dbPol1, dbPol2, channel.ChannelID,
×
5767
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5768
                )
×
5769
                if err != nil {
×
5770
                        return nil, fmt.Errorf("unable to build channel "+
×
5771
                                "policies: %w", err)
×
5772
                }
×
5773

5774
                edges = append(edges, ChannelEdge{
×
5775
                        Info:    channel,
×
5776
                        Policy1: p1,
×
5777
                        Policy2: p2,
×
5778
                        Node1:   node1,
×
5779
                        Node2:   node2,
×
5780
                })
×
5781
        }
5782

5783
        return edges, nil
×
5784
}
5785

5786
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5787
// instances from the provided rows using batch loading for channel data.
5788
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5789
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5790
        []*models.ChannelEdgeInfo, []int64, error) {
×
5791

×
5792
        if len(rows) == 0 {
×
5793
                return nil, nil, nil
×
5794
        }
×
5795

5796
        // Collect all the channel IDs needed for batch loading.
5797
        channelIDs := make([]int64, len(rows))
×
5798
        for i, row := range rows {
×
5799
                channelIDs[i] = row.Channel().ID
×
5800
        }
×
5801

5802
        // Batch load the channel data.
5803
        channelBatchData, err := batchLoadChannelData(
×
5804
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5805
        )
×
5806
        if err != nil {
×
5807
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5808
                        "data: %w", err)
×
5809
        }
×
5810

5811
        // Build all channel edges using batch data.
5812
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5813
        for _, row := range rows {
×
5814
                node1, node2, err := buildNodeVertices(
×
5815
                        row.Node1Pub(), row.Node2Pub(),
×
5816
                )
×
5817
                if err != nil {
×
5818
                        return nil, nil, err
×
5819
                }
×
5820

5821
                // Build channel info using batch data
5822
                info, err := buildEdgeInfoWithBatchData(
×
5823
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5824
                        channelBatchData,
×
5825
                )
×
5826
                if err != nil {
×
5827
                        return nil, nil, err
×
5828
                }
×
5829

5830
                edges = append(edges, info)
×
5831
        }
5832

5833
        return edges, channelIDs, nil
×
5834
}
5835

5836
// handleZombieMarking is a helper function that handles the logic of
5837
// marking a channel as a zombie in the database. It takes into account whether
5838
// we are in strict zombie pruning mode, and adjusts the node public keys
5839
// accordingly based on the last update timestamps of the channel policies.
5840
func handleZombieMarking(ctx context.Context, db SQLQueries,
5841
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5842
        strictZombiePruning bool, scid uint64) error {
×
5843

×
5844
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5845

×
5846
        if strictZombiePruning {
×
5847
                var e1UpdateTime, e2UpdateTime *time.Time
×
5848
                if row.Policy1LastUpdate.Valid {
×
5849
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5850
                        e1UpdateTime = &e1Time
×
5851
                }
×
5852
                if row.Policy2LastUpdate.Valid {
×
5853
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5854
                        e2UpdateTime = &e2Time
×
5855
                }
×
5856

5857
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5858
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5859
                        e2UpdateTime,
×
5860
                )
×
5861
        }
5862

5863
        return db.UpsertZombieChannel(
×
5864
                ctx, sqlc.UpsertZombieChannelParams{
×
5865
                        Version:  int16(lnwire.GossipVersion1),
×
5866
                        Scid:     channelIDToBytes(scid),
×
5867
                        NodeKey1: nodeKey1[:],
×
5868
                        NodeKey2: nodeKey2[:],
×
5869
                },
×
5870
        )
×
5871
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc