
lightningnetwork / lnd / 20002670438

07 Dec 2025 10:11AM UTC coverage: 65.2% (-0.004%) from 65.204%

Pull Request #10428: graphdb: fix potential sql tx exhaustion
Merge 6829f57c8 into a76f22da9

7 of 14 new or added lines in 2 files covered. (50.0%)

72 existing lines in 20 files now uncovered.

137638 of 211100 relevant lines covered (65.2%)

20798.46 hits per line
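
The new lines introduced by this PR (marked NEW in the listing below, inside
ChanUpdatesInHorizon) acquire the cache read lock before the read transaction
is opened, keeping the cacheMu -> DB lock ordering documented on the cacheMu
field and, per the in-code comment, preventing deadlock with write operations.
A rough sketch of that ordering, using only names that appear in this file
(the method name itself is invented for illustration):

// chanUpdatesLockOrderSketch is NOT part of sql_store.go; it only sketches
// the ordering the NEW lines establish: take cacheMu first, then open the
// SQL read transaction.
func (s *SQLStore) chanUpdatesLockOrderSketch(ctx context.Context) error {
        s.cacheMu.RLock()
        defer s.cacheMu.RUnlock()

        // The read tx is only started once the cache lock is already held,
        // matching the "cacheMu -> DB" comment in the diff below.
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // ... consult s.chanCache and read channel rows here ...
                return nil
        }, sqldb.NoOpReset)
}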

Source File

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
37
// execute queries against the SQL graph tables.
38
//
39
//nolint:ll,interfacebloat
40
type SQLQueries interface {
41
        /*
42
                Node queries.
43
        */
44
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
45
        UpsertSourceNode(ctx context.Context, arg sqlc.UpsertSourceNodeParams) (int64, error)
46
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
47
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
48
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
49
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
50
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
51
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
52
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
53
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
54
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
55
        DeleteNode(ctx context.Context, id int64) error
56

57
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
58
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
59
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
60
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
61

62
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
63
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
64
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
65
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
66

67
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
68
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
69
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
70
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
71
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
72

73
        /*
74
                Source node queries.
75
        */
76
        AddSourceNode(ctx context.Context, nodeID int64) error
77
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
78

79
        /*
80
                Channel queries.
81
        */
82
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
83
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
84
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
85
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
86
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
87
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
88
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
89
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
90
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
91
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
92
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
93
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
94
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
95
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
96
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
97
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
98
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
99
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
100
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
101
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
102
        DeleteChannels(ctx context.Context, ids []int64) error
103

104
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
105
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
106
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
107
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
108

109
        /*
110
                Channel Policy table queries.
111
        */
112
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
113
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
114
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
115

116
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
117
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
118
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
119

120
        /*
121
                Zombie index queries.
122
        */
123
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
124
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
125
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
126
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
127
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
128
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
129

130
        /*
131
                Prune log table queries.
132
        */
133
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
134
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
135
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
136
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
137
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
138

139
        /*
140
                Closed SCID table queries.
141
        */
142
        InsertClosedChannel(ctx context.Context, scid []byte) error
143
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
144
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
145

146
        /*
147
                Migration specific queries.
148

149
                NOTE: these should not be used in code other than migrations.
150
                Once sqldbv2 is in place, these can be removed from this struct
151
                as then migrations will have their own dedicated queries
152
                structs.
153
        */
154
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
155
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
156
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
157
}
158

159
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
160
// database operations.
161
type BatchedSQLQueries interface {
162
        SQLQueries
163
        sqldb.BatchedTx[SQLQueries]
164
}
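
Every SQLStore method below drives the database through this interface in the
same way: ExecTx (from the embedded sqldb.BatchedTx) opens the transaction,
the closure receives a per-transaction SQLQueries value, and a reset closure
(sqldb.NoOpReset when there is nothing to reset) is supplied for retries. A
condensed, hypothetical helper showing that pattern (the helper name and the
query chosen are illustrative only):

// countV1ZombiesSketch is NOT part of sql_store.go; it only illustrates the
// ExecTx pattern used throughout the methods below.
func countV1ZombiesSketch(ctx context.Context,
        db BatchedSQLQueries) (int64, error) {

        var count int64
        err := db.ExecTx(ctx, sqldb.ReadTxOpt(), func(q SQLQueries) error {
                var err error
                count, err = q.CountZombieChannels(
                        ctx, int16(lnwire.GossipVersion1),
                )
                return err
        }, sqldb.NoOpReset)

        return count, err
}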
165

166
// SQLStore is an implementation of the V1Store interface that uses a SQL
167
// database as the backend.
168
type SQLStore struct {
169
        cfg *SQLStoreConfig
170
        db  BatchedSQLQueries
171

172
        // cacheMu guards all caches (rejectCache and chanCache). If
173
        // this mutex will be acquired at the same time as the DB mutex then
174
        // the cacheMu MUST be acquired first to prevent deadlock.
175
        cacheMu     sync.RWMutex
176
        rejectCache *rejectCache
177
        chanCache   *channelCache
178

179
        chanScheduler batch.Scheduler[SQLQueries]
180
        nodeScheduler batch.Scheduler[SQLQueries]
181

182
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
183
        srcNodeMu sync.Mutex
184
}
185

186
// A compile-time assertion to ensure that SQLStore implements the V1Store
187
// interface.
188
var _ V1Store = (*SQLStore)(nil)
189

190
// SQLStoreConfig holds the configuration for the SQLStore.
191
type SQLStoreConfig struct {
192
        // ChainHash is the genesis hash for the chain that all the gossip
193
        // messages in this store are aimed at.
194
        ChainHash chainhash.Hash
195

196
        // QueryConfig holds configuration values for SQL queries.
197
        QueryCfg *sqldb.QueryConfig
198
}
199

200
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
201
// storage backend.
202
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
203
        options ...StoreOptionModifier) (*SQLStore, error) {
×
204

×
205
        opts := DefaultOptions()
×
206
        for _, o := range options {
×
207
                o(opts)
×
208
        }
×
209

210
        if opts.NoMigration {
×
211
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
212
                        "supported for SQL stores")
×
213
        }
×
214

215
        s := &SQLStore{
×
216
                cfg:         cfg,
×
217
                db:          db,
×
218
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
219
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
220
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
221
        }
×
222

×
223
        s.chanScheduler = batch.NewTimeScheduler(
×
224
                db, &s.cacheMu, opts.BatchCommitInterval,
×
225
        )
×
226
        s.nodeScheduler = batch.NewTimeScheduler(
×
227
                db, nil, opts.BatchCommitInterval,
×
228
        )
×
229

×
230
        return s, nil
×
231
}
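
A minimal construction sketch for the function above; the db argument is
assumed to be an already-configured BatchedSQLQueries backend and queryCfg a
populated *sqldb.QueryConfig, and the btcd chaincfg package (not imported by
this file) is assumed to supply the mainnet genesis hash:

// newMainnetGraphStoreSketch is NOT part of sql_store.go; it only shows how
// the pieces defined above fit together.
func newMainnetGraphStoreSketch(db BatchedSQLQueries,
        queryCfg *sqldb.QueryConfig) (*SQLStore, error) {

        cfg := &SQLStoreConfig{
                // Genesis hash of the chain the gossip messages target.
                ChainHash: *chaincfg.MainNetParams.GenesisHash,
                QueryCfg:  queryCfg,
        }

        return NewSQLStore(cfg, db)
}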
232

233
// AddNode adds a vertex/node to the graph database. If the node is not
234
// in the database from before, this will add a new, unconnected one to the
235
// graph. If it is present from before, this will update that node's
236
// information.
237
//
238
// NOTE: part of the V1Store interface.
239
func (s *SQLStore) AddNode(ctx context.Context,
240
        node *models.Node, opts ...batch.SchedulerOption) error {
×
241

×
242
        r := &batch.Request[SQLQueries]{
×
243
                Opts: batch.NewSchedulerOptions(opts...),
×
244
                Do: func(queries SQLQueries) error {
×
245
                        _, err := upsertNode(ctx, queries, node)
×
246

×
247
                        // It is possible that two of the same node
×
248
                        // announcements are both being processed in the same
×
249
                        // batch. This may cause the UpsertNode conflict to
×
250
                        // be hit since we require at the db layer that the
×
251
                        // new last_update is greater than the existing
×
252
                        // last_update. We need to gracefully handle this here.
×
253
                        if errors.Is(err, sql.ErrNoRows) {
×
254
                                return nil
×
255
                        }
×
256

257
                        return err
×
258
                },
259
        }
260

261
        return s.nodeScheduler.Execute(ctx, r)
×
262
}
263

264
// FetchNode attempts to look up a target node by its identity public
265
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
266
// returned.
267
//
268
// NOTE: part of the V1Store interface.
269
func (s *SQLStore) FetchNode(ctx context.Context,
270
        pubKey route.Vertex) (*models.Node, error) {
×
271

×
272
        var node *models.Node
×
273
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
274
                var err error
×
275
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
276

×
277
                return err
×
278
        }, sqldb.NoOpReset)
×
279
        if err != nil {
×
280
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
281
        }
×
282

283
        return node, nil
×
284
}
285

286
// HasNode determines if the graph has a vertex identified by the
287
// target node identity public key. If the node exists in the database, a
288
// timestamp of when the data for the node was last updated is returned along
289
// with a true boolean. Otherwise, an empty time.Time is returned with a false
290
// boolean.
291
//
292
// NOTE: part of the V1Store interface.
293
func (s *SQLStore) HasNode(ctx context.Context,
294
        pubKey [33]byte) (time.Time, bool, error) {
×
295

×
296
        var (
×
297
                exists     bool
×
298
                lastUpdate time.Time
×
299
        )
×
300
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
301
                dbNode, err := db.GetNodeByPubKey(
×
302
                        ctx, sqlc.GetNodeByPubKeyParams{
×
303
                                Version: int16(lnwire.GossipVersion1),
×
304
                                PubKey:  pubKey[:],
×
305
                        },
×
306
                )
×
307
                if errors.Is(err, sql.ErrNoRows) {
×
308
                        return nil
×
309
                } else if err != nil {
×
310
                        return fmt.Errorf("unable to fetch node: %w", err)
×
311
                }
×
312

313
                exists = true
×
314

×
315
                if dbNode.LastUpdate.Valid {
×
316
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
317
                }
×
318

319
                return nil
×
320
        }, sqldb.NoOpReset)
321
        if err != nil {
×
322
                return time.Time{}, false,
×
323
                        fmt.Errorf("unable to fetch node: %w", err)
×
324
        }
×
325

326
        return lastUpdate, exists, nil
×
327
}
328

329
// AddrsForNode returns all known addresses for the target node public key
330
// that the graph DB is aware of. The returned boolean indicates whether the
331
// given node is known to the graph DB.
332
//
333
// NOTE: part of the V1Store interface.
334
func (s *SQLStore) AddrsForNode(ctx context.Context,
335
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
336

×
337
        var (
×
338
                addresses []net.Addr
×
339
                known     bool
×
340
        )
×
341
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
342
                // First, check if the node exists and get its DB ID if it
×
343
                // does.
×
344
                dbID, err := db.GetNodeIDByPubKey(
×
345
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
346
                                Version: int16(lnwire.GossipVersion1),
×
347
                                PubKey:  nodePub.SerializeCompressed(),
×
348
                        },
×
349
                )
×
350
                if errors.Is(err, sql.ErrNoRows) {
×
351
                        return nil
×
352
                }
×
353

354
                known = true
×
355

×
356
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
357
                if err != nil {
×
358
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
359
                                err)
×
360
                }
×
361

362
                return nil
×
363
        }, sqldb.NoOpReset)
364
        if err != nil {
×
365
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
366
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
367
        }
×
368

369
        return known, addresses, nil
×
370
}
371

372
// DeleteNode starts a new database transaction to remove a vertex/node
373
// from the database according to the node's public key.
374
//
375
// NOTE: part of the V1Store interface.
376
func (s *SQLStore) DeleteNode(ctx context.Context,
377
        pubKey route.Vertex) error {
×
378

×
379
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
380
                res, err := db.DeleteNodeByPubKey(
×
381
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
382
                                Version: int16(lnwire.GossipVersion1),
×
383
                                PubKey:  pubKey[:],
×
384
                        },
×
385
                )
×
386
                if err != nil {
×
387
                        return err
×
388
                }
×
389

390
                rows, err := res.RowsAffected()
×
391
                if err != nil {
×
392
                        return err
×
393
                }
×
394

395
                if rows == 0 {
×
396
                        return ErrGraphNodeNotFound
×
397
                } else if rows > 1 {
×
398
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
399
                }
×
400

401
                return err
×
402
        }, sqldb.NoOpReset)
403
        if err != nil {
×
404
                return fmt.Errorf("unable to delete node: %w", err)
×
405
        }
×
406

407
        return nil
×
408
}
409

410
// FetchNodeFeatures returns the features of the given node. If no features are
411
// known for the node, an empty feature vector is returned.
412
//
413
// NOTE: this is part of the graphdb.NodeTraverser interface.
414
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
415
        *lnwire.FeatureVector, error) {
×
416

×
417
        ctx := context.TODO()
×
418

×
419
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
420
}
×
421

422
// DisabledChannelIDs returns the channel ids of disabled channels.
423
// A channel is disabled when both of the associated ChannelEdgePolicies
424
// have their disabled bit on.
425
//
426
// NOTE: part of the V1Store interface.
427
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
428
        var (
×
429
                ctx     = context.TODO()
×
430
                chanIDs []uint64
×
431
        )
×
432
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
433
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
434
                if err != nil {
×
435
                        return fmt.Errorf("unable to fetch disabled "+
×
436
                                "channels: %w", err)
×
437
                }
×
438

439
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
440

×
441
                return nil
×
442
        }, sqldb.NoOpReset)
443
        if err != nil {
×
444
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
445
                        err)
×
446
        }
×
447

448
        return chanIDs, nil
×
449
}
450

451
// LookupAlias attempts to return the alias as advertised by the target node.
452
//
453
// NOTE: part of the V1Store interface.
454
func (s *SQLStore) LookupAlias(ctx context.Context,
455
        pub *btcec.PublicKey) (string, error) {
×
456

×
457
        var alias string
×
458
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
459
                dbNode, err := db.GetNodeByPubKey(
×
460
                        ctx, sqlc.GetNodeByPubKeyParams{
×
461
                                Version: int16(lnwire.GossipVersion1),
×
462
                                PubKey:  pub.SerializeCompressed(),
×
463
                        },
×
464
                )
×
465
                if errors.Is(err, sql.ErrNoRows) {
×
466
                        return ErrNodeAliasNotFound
×
467
                } else if err != nil {
×
468
                        return fmt.Errorf("unable to fetch node: %w", err)
×
469
                }
×
470

471
                if !dbNode.Alias.Valid {
×
472
                        return ErrNodeAliasNotFound
×
473
                }
×
474

475
                alias = dbNode.Alias.String
×
476

×
477
                return nil
×
478
        }, sqldb.NoOpReset)
479
        if err != nil {
×
480
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
481
        }
×
482

483
        return alias, nil
×
484
}
485

486
// SourceNode returns the source node of the graph. The source node is treated
487
// as the center node within a star-graph. This method may be used to kick off
488
// a path finding algorithm in order to explore the reachability of another
489
// node based off the source node.
490
//
491
// NOTE: part of the V1Store interface.
492
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
493
        error) {
×
494

×
495
        var node *models.Node
×
496
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
497
                _, nodePub, err := s.getSourceNode(
×
498
                        ctx, db, lnwire.GossipVersion1,
×
499
                )
×
500
                if err != nil {
×
501
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
502
                                err)
×
503
                }
×
504

505
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
506

×
507
                return err
×
508
        }, sqldb.NoOpReset)
509
        if err != nil {
×
510
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
511
        }
×
512

513
        return node, nil
×
514
}
515

516
// SetSourceNode sets the source node within the graph database. The source
517
// node is to be used as the center of a star-graph within path finding
518
// algorithms.
519
//
520
// NOTE: part of the V1Store interface.
521
func (s *SQLStore) SetSourceNode(ctx context.Context,
522
        node *models.Node) error {
×
523

×
524
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
525
                // For the source node, we use a less strict upsert that allows
×
526
                // updates even when the timestamp hasn't changed. This handles
×
527
                // the race condition where multiple goroutines (e.g.,
×
528
                // setSelfNode, createNewHiddenService, RPC updates) read the
×
529
                // same old timestamp, independently increment it, and try to
×
530
                // write concurrently. We want all parameter changes to persist,
×
531
                // even if timestamps collide.
×
532
                id, err := upsertSourceNode(ctx, db, node)
×
533
                if err != nil {
×
534
                        return fmt.Errorf("unable to upsert source node: %w",
×
535
                                err)
×
536
                }
×
537

538
                // Make sure that if a source node for this version is already
539
                // set, then the ID is the same as the one we are about to set.
540
                dbSourceNodeID, _, err := s.getSourceNode(
×
541
                        ctx, db, lnwire.GossipVersion1,
×
542
                )
×
543
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
544
                        return fmt.Errorf("unable to fetch source node: %w",
×
545
                                err)
×
546
                } else if err == nil {
×
547
                        if dbSourceNodeID != id {
×
548
                                return fmt.Errorf("v1 source node already "+
×
549
                                        "set to a different node: %d vs %d",
×
550
                                        dbSourceNodeID, id)
×
551
                        }
×
552

553
                        return nil
×
554
                }
555

556
                return db.AddSourceNode(ctx, id)
×
557
        }, sqldb.NoOpReset)
558
}
559

560
// NodeUpdatesInHorizon returns all the known lightning nodes which have an
561
// update timestamp within the passed range. This method can be used by two
562
// nodes to quickly determine if they have the same set of up to date node
563
// announcements.
564
//
565
// NOTE: This is part of the V1Store interface.
566
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
567
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
568

×
569
        cfg := defaultIteratorConfig()
×
570
        for _, opt := range opts {
×
571
                opt(cfg)
×
572
        }
×
573

574
        return func(yield func(*models.Node, error) bool) {
×
575
                var (
×
576
                        ctx            = context.TODO()
×
577
                        lastUpdateTime sql.NullInt64
×
578
                        lastPubKey     = make([]byte, 33)
×
579
                        hasMore        = true
×
580
                )
×
581

×
582
                // Each iteration, we'll read a batch amount of nodes, yield
×
583
                // them, then decide if we have more or not.
×
584
                for hasMore {
×
585
                        var batch []*models.Node
×
586

×
587
                        //nolint:ll
×
588
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
589
                                //nolint:ll
×
590
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
591
                                        StartTime: sqldb.SQLInt64(
×
592
                                                startTime.Unix(),
×
593
                                        ),
×
594
                                        EndTime: sqldb.SQLInt64(
×
595
                                                endTime.Unix(),
×
596
                                        ),
×
597
                                        LastUpdate: lastUpdateTime,
×
598
                                        LastPubKey: lastPubKey,
×
599
                                        OnlyPublic: sql.NullBool{
×
600
                                                Bool:  cfg.iterPublicNodes,
×
601
                                                Valid: true,
×
602
                                        },
×
603
                                        MaxResults: sqldb.SQLInt32(
×
604
                                                cfg.nodeUpdateIterBatchSize,
×
605
                                        ),
×
606
                                }
×
607
                                rows, err := db.GetNodesByLastUpdateRange(
×
608
                                        ctx, params,
×
609
                                )
×
610
                                if err != nil {
×
611
                                        return err
×
612
                                }
×
613

614
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
615

×
616
                                err = forEachNodeInBatch(
×
617
                                        ctx, s.cfg.QueryCfg, db, rows,
×
618
                                        func(_ int64, node *models.Node) error {
×
619
                                                batch = append(batch, node)
×
620

×
621
                                                // Update pagination cursors
×
622
                                                // based on the last processed
×
623
                                                // node.
×
624
                                                lastUpdateTime = sql.NullInt64{
×
625
                                                        Int64: node.LastUpdate.
×
626
                                                                Unix(),
×
627
                                                        Valid: true,
×
628
                                                }
×
629
                                                lastPubKey = node.PubKeyBytes[:]
×
630

×
631
                                                return nil
×
632
                                        },
×
633
                                )
634
                                if err != nil {
×
635
                                        return fmt.Errorf("unable to build "+
×
636
                                                "nodes: %w", err)
×
637
                                }
×
638

639
                                return nil
×
640
                        }, func() {
×
641
                                batch = []*models.Node{}
×
642
                        })
×
643

644
                        if err != nil {
×
645
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
646
                                        "error: %v", err)
×
647

×
648
                                yield(&models.Node{}, err)
×
649

×
650
                                return
×
651
                        }
×
652

653
                        for _, node := range batch {
×
654
                                if !yield(node, nil) {
×
655
                                        return
×
656
                                }
×
657
                        }
658

659
                        // If the batch didn't yield anything, then we're done.
660
                        if len(batch) == 0 {
×
661
                                break
×
662
                        }
663
                }
664
        }
665
}
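
Because NodeUpdatesInHorizon returns an iter.Seq2, callers consume it with a
range-over-func loop and must check the error yielded alongside each node; a
small hypothetical consumer (not part of this file):

// collectNodeUpdatesSketch drains the iterator above into a slice. Breaking
// or returning from the loop body stops the underlying batched pagination.
func collectNodeUpdatesSketch(s *SQLStore, start,
        end time.Time) ([]*models.Node, error) {

        var nodes []*models.Node
        for node, err := range s.NodeUpdatesInHorizon(start, end) {
                if err != nil {
                        return nil, err
                }

                nodes = append(nodes, node)
        }

        return nodes, nil
}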
666

667
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
668
// undirected edge from the two target nodes are created. The information stored
669
// denotes the static attributes of the channel, such as the channelID, the keys
670
// involved in creation of the channel, and the set of features that the channel
671
// supports. The chanPoint and chanID are used to uniquely identify the edge
672
// globally within the database.
673
//
674
// NOTE: part of the V1Store interface.
675
func (s *SQLStore) AddChannelEdge(ctx context.Context,
676
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
677

×
678
        var alreadyExists bool
×
679
        r := &batch.Request[SQLQueries]{
×
680
                Opts: batch.NewSchedulerOptions(opts...),
×
681
                Reset: func() {
×
682
                        alreadyExists = false
×
683
                },
×
684
                Do: func(tx SQLQueries) error {
×
685
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
686

×
687
                        // Make sure that the channel doesn't already exist. We
×
688
                        // do this explicitly instead of relying on catching a
×
689
                        // unique constraint error because relying on SQL to
×
690
                        // throw that error would abort the entire batch of
×
691
                        // transactions.
×
692
                        _, err := tx.GetChannelBySCID(
×
693
                                ctx, sqlc.GetChannelBySCIDParams{
×
694
                                        Scid:    chanIDB,
×
695
                                        Version: int16(lnwire.GossipVersion1),
×
696
                                },
×
697
                        )
×
698
                        if err == nil {
×
699
                                alreadyExists = true
×
700
                                return nil
×
701
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
702
                                return fmt.Errorf("unable to fetch channel: %w",
×
703
                                        err)
×
704
                        }
×
705

706
                        return insertChannel(ctx, tx, edge)
×
707
                },
708
                OnCommit: func(err error) error {
×
709
                        switch {
×
710
                        case err != nil:
×
711
                                return err
×
712
                        case alreadyExists:
×
713
                                return ErrEdgeAlreadyExist
×
714
                        default:
×
715
                                s.rejectCache.remove(edge.ChannelID)
×
716
                                s.chanCache.remove(edge.ChannelID)
×
717
                                return nil
×
718
                        }
719
                },
720
        }
721

722
        return s.chanScheduler.Execute(ctx, r)
×
723
}
724

725
// HighestChanID returns the "highest" known channel ID in the channel graph.
726
// This represents the "newest" channel from the PoV of the chain. This method
727
// can be used by peers to quickly determine if their graphs are in sync.
728
//
729
// NOTE: This is part of the V1Store interface.
730
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
731
        var highestChanID uint64
×
732
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
733
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
734
                if errors.Is(err, sql.ErrNoRows) {
×
735
                        return nil
×
736
                } else if err != nil {
×
737
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
738
                                err)
×
739
                }
×
740

741
                highestChanID = byteOrder.Uint64(chanID)
×
742

×
743
                return nil
×
744
        }, sqldb.NoOpReset)
745
        if err != nil {
×
746
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
747
        }
×
748

749
        return highestChanID, nil
×
750
}
751

752
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
753
// within the database for the referenced channel. The `flags` attribute within
754
// the ChannelEdgePolicy determines which of the directed edges are being
755
// updated. If the flag is 1, then the first node's information is being
756
// updated, otherwise it's the second node's information. The node ordering is
757
// determined by the lexicographical ordering of the identity public keys of the
758
// nodes on either side of the channel.
759
//
760
// NOTE: part of the V1Store interface.
761
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
762
        edge *models.ChannelEdgePolicy,
763
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
764

×
765
        var (
×
766
                isUpdate1    bool
×
767
                edgeNotFound bool
×
768
                from, to     route.Vertex
×
769
        )
×
770

×
771
        r := &batch.Request[SQLQueries]{
×
772
                Opts: batch.NewSchedulerOptions(opts...),
×
773
                Reset: func() {
×
774
                        isUpdate1 = false
×
775
                        edgeNotFound = false
×
776
                },
×
777
                Do: func(tx SQLQueries) error {
×
778
                        var err error
×
779
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
780
                                ctx, tx, edge,
×
781
                        )
×
782
                        // It is possible that two of the same policy
×
783
                        // announcements are both being processed in the same
×
784
                        // batch. This may cause the UpsertEdgePolicy conflict to
×
785
                        // be hit since we require at the db layer that the
×
786
                        // new last_update is greater than the existing
×
787
                        // last_update. We need to gracefully handle this here.
×
788
                        if errors.Is(err, sql.ErrNoRows) {
×
789
                                return nil
×
790
                        } else if err != nil {
×
791
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
792
                        }
×
793

794
                        // Silence ErrEdgeNotFound so that the batch can
795
                        // succeed, but propagate the error via local state.
796
                        if errors.Is(err, ErrEdgeNotFound) {
×
797
                                edgeNotFound = true
×
798
                                return nil
×
799
                        }
×
800

801
                        return err
×
802
                },
803
                OnCommit: func(err error) error {
×
804
                        switch {
×
805
                        case err != nil:
×
806
                                return err
×
807
                        case edgeNotFound:
×
808
                                return ErrEdgeNotFound
×
809
                        default:
×
810
                                s.updateEdgeCache(edge, isUpdate1)
×
811
                                return nil
×
812
                        }
813
                },
814
        }
815

816
        err := s.chanScheduler.Execute(ctx, r)
×
817

×
818
        return from, to, err
×
819
}
820

821
// updateEdgeCache updates our reject and channel caches with the new
822
// edge policy information.
823
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
824
        isUpdate1 bool) {
×
825

×
826
        // If an entry for this channel is found in reject cache, we'll modify
×
827
        // the entry with the updated timestamp for the direction that was just
×
828
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
829
        // during the next query for this edge.
×
830
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
831
                if isUpdate1 {
×
832
                        entry.upd1Time = e.LastUpdate.Unix()
×
833
                } else {
×
834
                        entry.upd2Time = e.LastUpdate.Unix()
×
835
                }
×
836
                s.rejectCache.insert(e.ChannelID, entry)
×
837
        }
838

839
        // If an entry for this channel is found in channel cache, we'll modify
840
        // the entry with the updated policy for the direction that was just
841
        // written. If the edge doesn't exist, we'll defer loading the info and
842
        // policies and lazily read from disk during the next query.
843
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
844
                if isUpdate1 {
×
845
                        channel.Policy1 = e
×
846
                } else {
×
847
                        channel.Policy2 = e
×
848
                }
×
849
                s.chanCache.insert(e.ChannelID, channel)
×
850
        }
851
}
852

853
// ForEachSourceNodeChannel iterates through all channels of the source node,
854
// executing the passed callback on each. The call-back is provided with the
855
// channel's outpoint, whether we have a policy for the channel and the channel
856
// peer's node information.
857
//
858
// NOTE: part of the V1Store interface.
859
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
860
        cb func(chanPoint wire.OutPoint, havePolicy bool,
861
                otherNode *models.Node) error, reset func()) error {
×
862

×
863
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
864
                nodeID, nodePub, err := s.getSourceNode(
×
865
                        ctx, db, lnwire.GossipVersion1,
×
866
                )
×
867
                if err != nil {
×
868
                        return fmt.Errorf("unable to fetch source node: %w",
×
869
                                err)
×
870
                }
×
871

872
                return forEachNodeChannel(
×
873
                        ctx, db, s.cfg, nodeID,
×
874
                        func(info *models.ChannelEdgeInfo,
×
875
                                outPolicy *models.ChannelEdgePolicy,
×
876
                                _ *models.ChannelEdgePolicy) error {
×
877

×
878
                                // Fetch the other node.
×
879
                                var (
×
880
                                        otherNodePub [33]byte
×
881
                                        node1        = info.NodeKey1Bytes
×
882
                                        node2        = info.NodeKey2Bytes
×
883
                                )
×
884
                                switch {
×
885
                                case bytes.Equal(node1[:], nodePub[:]):
×
886
                                        otherNodePub = node2
×
887
                                case bytes.Equal(node2[:], nodePub[:]):
×
888
                                        otherNodePub = node1
×
889
                                default:
×
890
                                        return fmt.Errorf("node not " +
×
891
                                                "participating in this channel")
×
892
                                }
893

894
                                _, otherNode, err := getNodeByPubKey(
×
895
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
896
                                )
×
897
                                if err != nil {
×
898
                                        return fmt.Errorf("unable to fetch "+
×
899
                                                "other node(%x): %w",
×
900
                                                otherNodePub, err)
×
901
                                }
×
902

903
                                return cb(
×
904
                                        info.ChannelPoint, outPolicy != nil,
×
905
                                        otherNode,
×
906
                                )
×
907
                        },
908
                )
909
        }, reset)
910
}
911

912
// ForEachNode iterates through all the stored vertices/nodes in the graph,
913
// executing the passed callback with each node encountered. If the callback
914
// returns an error, then the transaction is aborted and the iteration stops
915
// early.
916
//
917
// NOTE: part of the V1Store interface.
918
func (s *SQLStore) ForEachNode(ctx context.Context,
919
        cb func(node *models.Node) error, reset func()) error {
×
920

×
921
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
922
                return forEachNodePaginated(
×
923
                        ctx, s.cfg.QueryCfg, db,
×
924
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
925
                                node *models.Node) error {
×
926

×
927
                                return cb(node)
×
928
                        },
×
929
                )
930
        }, reset)
931
}
932

933
// ForEachNodeDirectedChannel iterates through all channels of a given node,
934
// executing the passed callback on the directed edge representing the channel
935
// and its incoming policy. If the callback returns an error, then the iteration
936
// is halted with the error propagated back up to the caller.
937
//
938
// Unknown policies are passed into the callback as nil values.
939
//
940
// NOTE: this is part of the graphdb.NodeTraverser interface.
941
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
942
        cb func(channel *DirectedChannel) error, reset func()) error {
×
943

×
944
        var ctx = context.TODO()
×
945

×
946
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
947
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
948
        }, reset)
×
949
}
950

951
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
952
// graph, executing the passed callback with each node encountered. If the
953
// callback returns an error, then the transaction is aborted and the iteration
954
// stops early.
955
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
956
        cb func(route.Vertex, *lnwire.FeatureVector) error,
957
        reset func()) error {
×
958

×
959
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
960
                return forEachNodeCacheable(
×
961
                        ctx, s.cfg.QueryCfg, db,
×
962
                        func(_ int64, nodePub route.Vertex,
×
963
                                features *lnwire.FeatureVector) error {
×
964

×
965
                                return cb(nodePub, features)
×
966
                        },
×
967
                )
968
        }, reset)
969
        if err != nil {
×
970
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
971
        }
×
972

973
        return nil
×
974
}
975

976
// ForEachNodeChannel iterates through all channels of the given node,
977
// executing the passed callback with an edge info structure and the policies
978
// of each end of the channel. The first edge policy is the outgoing edge *to*
979
// the connecting node, while the second is the incoming edge *from* the
980
// connecting node. If the callback returns an error, then the iteration is
981
// halted with the error propagated back up to the caller.
982
//
983
// Unknown policies are passed into the callback as nil values.
984
//
985
// NOTE: part of the V1Store interface.
986
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
987
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
988
                *models.ChannelEdgePolicy) error, reset func()) error {
×
989

×
990
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
991
                dbNode, err := db.GetNodeByPubKey(
×
992
                        ctx, sqlc.GetNodeByPubKeyParams{
×
993
                                Version: int16(lnwire.GossipVersion1),
×
994
                                PubKey:  nodePub[:],
×
995
                        },
×
996
                )
×
997
                if errors.Is(err, sql.ErrNoRows) {
×
998
                        return nil
×
999
                } else if err != nil {
×
1000
                        return fmt.Errorf("unable to fetch node: %w", err)
×
1001
                }
×
1002

1003
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
1004
        }, reset)
1005
}
1006

1007
// extractMaxUpdateTime returns the maximum of the two policy update times.
1008
// This is used for pagination cursor tracking.
1009
func extractMaxUpdateTime(
1010
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1011

×
1012
        switch {
×
1013
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1014
                return max(row.Policy1LastUpdate.Int64,
×
1015
                        row.Policy2LastUpdate.Int64)
×
1016
        case row.Policy1LastUpdate.Valid:
×
1017
                return row.Policy1LastUpdate.Int64
×
1018
        case row.Policy2LastUpdate.Valid:
×
1019
                return row.Policy2LastUpdate.Int64
×
1020
        default:
×
1021
                return 0
×
1022
        }
1023
}
1024

1025
// buildChannelFromRow constructs a ChannelEdge from a database row.
1026
// This includes building the nodes, channel info, and policies.
1027
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1028
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1029

×
1030
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1031
        if err != nil {
×
1032
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1033
                        err)
×
1034
        }
×
1035

1036
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1037
        if err != nil {
×
1038
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1039
                        err)
×
1040
        }
×
1041

1042
        channel, err := getAndBuildEdgeInfo(
×
1043
                ctx, s.cfg, db,
×
1044
                row.GraphChannel, node1.PubKeyBytes,
×
1045
                node2.PubKeyBytes,
×
1046
        )
×
1047
        if err != nil {
×
1048
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1049
                        "channel info: %w", err)
×
1050
        }
×
1051

1052
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1053
        if err != nil {
×
1054
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1055
                        "channel policies: %w", err)
×
1056
        }
×
1057

1058
        p1, p2, err := getAndBuildChanPolicies(
×
1059
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1060
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1061
        )
×
1062
        if err != nil {
×
1063
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1064
                        "channel policies: %w", err)
×
1065
        }
×
1066

1067
        return ChannelEdge{
×
1068
                Info:    channel,
×
1069
                Policy1: p1,
×
1070
                Policy2: p2,
×
1071
                Node1:   node1,
×
1072
                Node2:   node2,
×
1073
        }, nil
×
1074
}
1075

1076
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1077
// This method acquires the cache lock only once for the entire batch.
1078
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1079
        if len(edgesToCache) == 0 {
×
1080
                return
×
1081
        }
×
1082

1083
        s.cacheMu.Lock()
×
1084
        defer s.cacheMu.Unlock()
×
1085

×
1086
        for chanID, edge := range edgesToCache {
×
1087
                s.chanCache.insert(chanID, edge)
×
1088
        }
×
1089
}
1090

1091
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1092
// one edge that has an update timestamp within the specified horizon.
1093
//
1094
// Iterator Lifecycle:
1095
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1096
// 2. Query batch of channels with policies in time range
1097
// 3. For each channel: check if seen, check cache, or build from DB
1098
// 4. Yield channels to caller
1099
// 5. Update cache after successful batch
1100
// 6. Repeat with updated pagination cursor until no more results
1101
//
1102
// NOTE: This is part of the V1Store interface.
1103
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1104
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1105

×
1106
        // Apply options.
×
1107
        cfg := defaultIteratorConfig()
×
1108
        for _, opt := range opts {
×
1109
                opt(cfg)
×
1110
        }
×
1111

1112
        return func(yield func(ChannelEdge, error) bool) {
×
1113
                var (
×
1114
                        ctx            = context.TODO()
×
1115
                        edgesSeen      = make(map[uint64]struct{})
×
1116
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1117
                        hits           int
×
1118
                        total          int
×
1119
                        lastUpdateTime sql.NullInt64
×
1120
                        lastID         sql.NullInt64
×
1121
                        hasMore        = true
×
1122
                )
×
1123

×
1124
                // Each iteration, we'll read a batch amount of channel updates
×
1125
                // (consulting the cache along the way), yield them, then loop
×
1126
                // back to decide if we have any more updates to read out.
×
1127
                for hasMore {
×
1128
                        var batch []ChannelEdge
×
1129

×
NEW
1130
                        // Acquire read lock before starting transaction to ensure
×
NEW
1131
                        // consistent lock ordering (cacheMu -> DB) and prevent
×
NEW
1132
                        // deadlock with write operations.
×
NEW
1133
                        s.cacheMu.RLock()
×
NEW
1134

×
1135
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1136
                                func(db SQLQueries) error {
×
1137
                                        //nolint:ll
×
1138
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1139
                                                Version: int16(lnwire.GossipVersion1),
×
1140
                                                StartTime: sqldb.SQLInt64(
×
1141
                                                        startTime.Unix(),
×
1142
                                                ),
×
1143
                                                EndTime: sqldb.SQLInt64(
×
1144
                                                        endTime.Unix(),
×
1145
                                                ),
×
1146
                                                LastUpdateTime: lastUpdateTime,
×
1147
                                                LastID:         lastID,
×
1148
                                                MaxResults: sql.NullInt32{
×
1149
                                                        Int32: int32(
×
1150
                                                                cfg.chanUpdateIterBatchSize,
×
1151
                                                        ),
×
1152
                                                        Valid: true,
×
1153
                                                },
×
1154
                                        }
×
1155
                                        //nolint:ll
×
1156
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1157
                                                ctx, params,
×
1158
                                        )
×
1159
                                        if err != nil {
×
1160
                                                return err
×
1161
                                        }
×
1162

1163
                                        //nolint:ll
1164
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1165

×
1166
                                        //nolint:ll
×
1167
                                        for _, row := range rows {
×
1168
                                                lastUpdateTime = sql.NullInt64{
×
1169
                                                        Int64: extractMaxUpdateTime(row),
×
1170
                                                        Valid: true,
×
1171
                                                }
×
1172
                                                lastID = sql.NullInt64{
×
1173
                                                        Int64: row.GraphChannel.ID,
×
1174
                                                        Valid: true,
×
1175
                                                }
×
1176

×
1177
                                                // Skip if we've already
×
1178
                                                // processed this channel.
×
1179
                                                chanIDInt := byteOrder.Uint64(
×
1180
                                                        row.GraphChannel.Scid,
×
1181
                                                )
×
1182
                                                _, ok := edgesSeen[chanIDInt]
×
1183
                                                if ok {
×
1184
                                                        continue
×
1185
                                                }
1186

1187
                                                // Check cache (we already hold RLock).
1188
                                                channel, ok := s.chanCache.get(
×
1189
                                                        chanIDInt,
×
1190
                                                )
×
1191
                                                if ok {
×
1192
                                                        hits++
×
1193
                                                        total++
×
1194
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1195
                                                        batch = append(batch, channel)
×
1196

×
1197
                                                        continue
×
1198
                                                }
1199

1200
                                                chanEdge, err := s.buildChannelFromRow(
×
1201
                                                        ctx, db, row,
×
1202
                                                )
×
1203
                                                if err != nil {
×
1204
                                                        return err
×
1205
                                                }
×
1206

1207
                                                edgesSeen[chanIDInt] = struct{}{}
×
1208
                                                edgesToCache[chanIDInt] = chanEdge
×
1209

×
1210
                                                batch = append(batch, chanEdge)
×
1211

×
1212
                                                total++
×
1213
                                        }
1214

1215
                                        return nil
×
1216
                                }, func() {
×
1217
                                        batch = nil
×
1218
                                        edgesSeen = make(map[uint64]struct{})
×
1219
                                        edgesToCache = make(
×
1220
                                                map[uint64]ChannelEdge,
×
1221
                                        )
×
1222
                                })
×
1223

1224
                        // Release read lock after transaction completes.
NEW
1225
                        s.cacheMu.RUnlock()
×
NEW
1226

×
1227
                        if err != nil {
×
1228
                                log.Errorf("ChanUpdatesInHorizon "+
×
1229
                                        "batch error: %v", err)
×
1230

×
1231
                                yield(ChannelEdge{}, err)
×
1232

×
1233
                                return
×
1234
                        }
×
1235

1236
                        for _, edge := range batch {
×
1237
                                if !yield(edge, nil) {
×
1238
                                        return
×
1239
                                }
×
1240
                        }
1241

1242
                        // Update cache after successful batch yield, taking
1243
                        // the cache lock only once for the entire batch.
1244
                        s.updateChanCacheBatch(edgesToCache)
×
1245
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1246

×
1247
                        // If the batch didn't yield anything, then we're done.
×
1248
                        if len(batch) == 0 {
×
1249
                                break
×
1250
                        }
1251
                }
1252

1253
                if total > 0 {
×
1254
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1255
                                "%.2f (%d/%d)",
×
1256
                                float64(hits)*100/float64(total), hits, total)
×
1257
                } else {
×
1258
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1259
                                "in horizon (%s, %s)", startTime, endTime)
×
1260
                }
×
1261
        }
1262
}
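
// exampleChanUpdatesInHorizon is an illustrative usage sketch (not part of
// the original file) showing how a caller might drain the iterator returned
// by ChanUpdatesInHorizon. The loop must stop on the first non-nil error,
// since the iterator terminates after yielding it.
func exampleChanUpdatesInHorizon(s *SQLStore) error {
        var (
                end   = time.Now()
                start = end.Add(-24 * time.Hour)
        )

        for edge, err := range s.ChanUpdatesInHorizon(start, end) {
                if err != nil {
                        return err
                }

                // Each ChannelEdge carries the channel announcement info
                // along with both directed policies, where known.
                log.Debugf("got update for channel %d", edge.Info.ChannelID)
        }

        return nil
}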
1263

1264
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1265
// data to the call-back. If withAddrs is true, then the call-back will also be
1266
// provided with the addresses associated with the node. The address retrieval
1267
// results in an additional round-trip to the database, so it should only be used
1268
// if the addresses are actually needed.
1269
//
1270
// NOTE: part of the V1Store interface.
1271
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1272
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1273
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1274

×
1275
        type nodeCachedBatchData struct {
×
1276
                features      map[int64][]int
×
1277
                addrs         map[int64][]nodeAddress
×
1278
                chanBatchData *batchChannelData
×
1279
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1280
        }
×
1281

×
1282
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1283
                // pageQueryFunc is used to query the next page of nodes.
×
1284
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1285
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1286

×
1287
                        return db.ListNodeIDsAndPubKeys(
×
1288
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1289
                                        Version: int16(lnwire.GossipVersion1),
×
1290
                                        ID:      lastID,
×
1291
                                        Limit:   limit,
×
1292
                                },
×
1293
                        )
×
1294
                }
×
1295

1296
                // batchDataFunc is then used to batch load the data required
1297
                // for each page of nodes.
1298
                batchDataFunc := func(ctx context.Context,
×
1299
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1300

×
1301
                        // Batch load node features.
×
1302
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1303
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1304
                        )
×
1305
                        if err != nil {
×
1306
                                return nil, fmt.Errorf("unable to batch load "+
×
1307
                                        "node features: %w", err)
×
1308
                        }
×
1309

1310
                        // Maybe fetch the node's addresses if requested.
1311
                        var nodeAddrs map[int64][]nodeAddress
×
1312
                        if withAddrs {
×
1313
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1314
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1315
                                )
×
1316
                                if err != nil {
×
1317
                                        return nil, fmt.Errorf("unable to "+
×
1318
                                                "batch load node "+
×
1319
                                                "addresses: %w", err)
×
1320
                                }
×
1321
                        }
1322

1323
                        // Batch load ALL unique channels for ALL nodes in this
1324
                        // page.
1325
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1326
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1327
                                        Version:  int16(lnwire.GossipVersion1),
×
1328
                                        Node1Ids: nodeIDs,
×
1329
                                        Node2Ids: nodeIDs,
×
1330
                                },
×
1331
                        )
×
1332
                        if err != nil {
×
1333
                                return nil, fmt.Errorf("unable to batch "+
×
1334
                                        "fetch channels for nodes: %w", err)
×
1335
                        }
×
1336

1337
                        // Deduplicate channels and collect IDs.
1338
                        var (
×
1339
                                allChannelIDs []int64
×
1340
                                allPolicyIDs  []int64
×
1341
                        )
×
1342
                        uniqueChannels := make(
×
1343
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1344
                        )
×
1345

×
1346
                        for _, channel := range allChannels {
×
1347
                                channelID := channel.GraphChannel.ID
×
1348

×
1349
                                // Only process each unique channel once.
×
1350
                                _, exists := uniqueChannels[channelID]
×
1351
                                if exists {
×
1352
                                        continue
×
1353
                                }
1354

1355
                                uniqueChannels[channelID] = channel
×
1356
                                allChannelIDs = append(allChannelIDs, channelID)
×
1357

×
1358
                                if channel.Policy1ID.Valid {
×
1359
                                        allPolicyIDs = append(
×
1360
                                                allPolicyIDs,
×
1361
                                                channel.Policy1ID.Int64,
×
1362
                                        )
×
1363
                                }
×
1364
                                if channel.Policy2ID.Valid {
×
1365
                                        allPolicyIDs = append(
×
1366
                                                allPolicyIDs,
×
1367
                                                channel.Policy2ID.Int64,
×
1368
                                        )
×
1369
                                }
×
1370
                        }
1371

1372
                        // Batch load channel data for all unique channels.
1373
                        channelBatchData, err := batchLoadChannelData(
×
1374
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1375
                                allPolicyIDs,
×
1376
                        )
×
1377
                        if err != nil {
×
1378
                                return nil, fmt.Errorf("unable to batch "+
×
1379
                                        "load channel data: %w", err)
×
1380
                        }
×
1381

1382
                        // Create map of node ID to channels that involve this
1383
                        // node.
1384
                        nodeIDSet := make(map[int64]bool)
×
1385
                        for _, nodeID := range nodeIDs {
×
1386
                                nodeIDSet[nodeID] = true
×
1387
                        }
×
1388

1389
                        nodeChannelMap := make(
×
1390
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1391
                        )
×
1392
                        for _, channel := range uniqueChannels {
×
1393
                                // Add channel to both nodes if they're in our
×
1394
                                // current page.
×
1395
                                node1 := channel.GraphChannel.NodeID1
×
1396
                                if nodeIDSet[node1] {
×
1397
                                        nodeChannelMap[node1] = append(
×
1398
                                                nodeChannelMap[node1], channel,
×
1399
                                        )
×
1400
                                }
×
1401
                                node2 := channel.GraphChannel.NodeID2
×
1402
                                if nodeIDSet[node2] {
×
1403
                                        nodeChannelMap[node2] = append(
×
1404
                                                nodeChannelMap[node2], channel,
×
1405
                                        )
×
1406
                                }
×
1407
                        }
1408

1409
                        return &nodeCachedBatchData{
×
1410
                                features:      nodeFeatures,
×
1411
                                addrs:         nodeAddrs,
×
1412
                                chanBatchData: channelBatchData,
×
1413
                                chanMap:       nodeChannelMap,
×
1414
                        }, nil
×
1415
                }
1416

1417
                // processItem is used to process each node in the current page.
1418
                processItem := func(ctx context.Context,
×
1419
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1420
                        batchData *nodeCachedBatchData) error {
×
1421

×
1422
                        // Build feature vector for this node.
×
1423
                        fv := lnwire.EmptyFeatureVector()
×
1424
                        features, exists := batchData.features[nodeData.ID]
×
1425
                        if exists {
×
1426
                                for _, bit := range features {
×
1427
                                        fv.Set(lnwire.FeatureBit(bit))
×
1428
                                }
×
1429
                        }
1430

1431
                        var nodePub route.Vertex
×
1432
                        copy(nodePub[:], nodeData.PubKey)
×
1433

×
1434
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1435

×
1436
                        toNodeCallback := func() route.Vertex {
×
1437
                                return nodePub
×
1438
                        }
×
1439

1440
                        // Build cached channels map for this node.
1441
                        channels := make(map[uint64]*DirectedChannel)
×
1442
                        for _, channelRow := range nodeChannels {
×
1443
                                directedChan, err := buildDirectedChannel(
×
1444
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1445
                                        channelRow, batchData.chanBatchData, fv,
×
1446
                                        toNodeCallback,
×
1447
                                )
×
1448
                                if err != nil {
×
1449
                                        return err
×
1450
                                }
×
1451

1452
                                channels[directedChan.ChannelID] = directedChan
×
1453
                        }
1454

1455
                        addrs, err := buildNodeAddresses(
×
1456
                                batchData.addrs[nodeData.ID],
×
1457
                        )
×
1458
                        if err != nil {
×
1459
                                return fmt.Errorf("unable to build node "+
×
1460
                                        "addresses: %w", err)
×
1461
                        }
×
1462

1463
                        return cb(ctx, nodePub, addrs, channels)
×
1464
                }
1465

1466
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1467
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1468
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1469
                                return node.ID
×
1470
                        },
×
1471
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1472
                                error) {
×
1473

×
1474
                                return node.ID, nil
×
1475
                        },
×
1476
                        batchDataFunc, processItem,
1477
                )
1478
        }, reset)
1479
}
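
// exampleForEachNodeCached is an illustrative usage sketch (not part of the
// original file) that counts every directed channel seen while walking the
// graph. Addresses are skipped (withAddrs=false) to avoid the extra
// round-trip noted above, and the reset closure clears the accumulator in
// case the underlying read transaction is retried.
func exampleForEachNodeCached(ctx context.Context, s *SQLStore) (int, error) {
        var numChans int

        err := s.ForEachNodeCached(ctx, false,
                func(_ context.Context, _ route.Vertex, _ []net.Addr,
                        chans map[uint64]*DirectedChannel) error {

                        numChans += len(chans)

                        return nil
                },
                func() {
                        numChans = 0
                },
        )

        return numChans, err
}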
1480

1481
// ForEachChannelCacheable iterates through all the channel edges stored
1482
// within the graph and invokes the passed callback for each edge. The
1483
// callback takes two edges since this is a directed graph, and both the
1484
// in/out edges are visited. If the callback returns an error, then the
1485
// transaction is aborted and the iteration stops early.
1486
//
1487
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1488
// pointer for that particular channel edge routing policy will be
1489
// passed into the callback.
1490
//
1491
// NOTE: this method is like ForEachChannel but fetches only the data
1492
// required for the graph cache.
1493
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1494
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1495
        reset func()) error {
×
1496

×
1497
        ctx := context.TODO()
×
1498

×
1499
        handleChannel := func(_ context.Context,
×
1500
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1501

×
1502
                node1, node2, err := buildNodeVertices(
×
1503
                        row.Node1Pubkey, row.Node2Pubkey,
×
1504
                )
×
1505
                if err != nil {
×
1506
                        return err
×
1507
                }
×
1508

1509
                edge := buildCacheableChannelInfo(
×
1510
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1511
                )
×
1512

×
1513
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1514
                if err != nil {
×
1515
                        return err
×
1516
                }
×
1517

1518
                pol1, pol2, err := buildCachedChanPolicies(
×
1519
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1520
                )
×
1521
                if err != nil {
×
1522
                        return err
×
1523
                }
×
1524

1525
                return cb(edge, pol1, pol2)
×
1526
        }
1527

1528
        extractCursor := func(
×
1529
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1530

×
1531
                return row.ID
×
1532
        }
×
1533

1534
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1535
                //nolint:ll
×
1536
                queryFunc := func(ctx context.Context, lastID int64,
×
1537
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1538
                        error) {
×
1539

×
1540
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1541
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1542
                                        Version: int16(lnwire.GossipVersion1),
×
1543
                                        ID:      lastID,
×
1544
                                        Limit:   limit,
×
1545
                                },
×
1546
                        )
×
1547
                }
×
1548

1549
                return sqldb.ExecutePaginatedQuery(
×
1550
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1551
                        extractCursor, handleChannel,
×
1552
                )
×
1553
        }, reset)
1554
}
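
// exampleForEachChannelCacheable is an illustrative usage sketch (not part of
// the original file) that builds an in-memory map of cached edge infos keyed
// by channel ID. Either policy argument may be nil if that direction was
// never advertised, so the callback ignores both here.
func exampleForEachChannelCacheable(s *SQLStore) (
        map[uint64]*models.CachedEdgeInfo, error) {

        channels := make(map[uint64]*models.CachedEdgeInfo)

        err := s.ForEachChannelCacheable(
                func(info *models.CachedEdgeInfo, _ *models.CachedEdgePolicy,
                        _ *models.CachedEdgePolicy) error {

                        channels[info.ChannelID] = info

                        return nil
                },
                func() {
                        // Drop any partially collected state if the read
                        // transaction is retried.
                        channels = make(map[uint64]*models.CachedEdgeInfo)
                },
        )

        return channels, err
}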
1555

1556
// ForEachChannel iterates through all the channel edges stored within the
1557
// graph and invokes the passed callback for each edge. The callback takes two
1558
// edges since this is a directed graph, and both the in/out edges are visited.
1559
// If the callback returns an error, then the transaction is aborted and the
1560
// iteration stops early.
1561
//
1562
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1563
// for that particular channel edge routing policy will be passed into the
1564
// callback.
1565
//
1566
// NOTE: part of the V1Store interface.
1567
func (s *SQLStore) ForEachChannel(ctx context.Context,
1568
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1569
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1570

×
1571
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1572
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1573
        }, reset)
×
1574
}
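
// exampleForEachChannel is an illustrative usage sketch (not part of the
// original file) that sums the advertised capacity across all channels. The
// Capacity field on models.ChannelEdgeInfo is assumed to be a btcutil.Amount.
func exampleForEachChannel(ctx context.Context, s *SQLStore) (btcutil.Amount,
        error) {

        var total btcutil.Amount

        err := s.ForEachChannel(ctx,
                func(info *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy,
                        _ *models.ChannelEdgePolicy) error {

                        total += info.Capacity

                        return nil
                },
                func() {
                        // Reset the running total if the transaction is
                        // retried.
                        total = 0
                },
        )

        return total, err
}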
1575

1576
// FilterChannelRange returns the channel ID's of all known channels which were
1577
// mined in a block height within the passed range. The channel IDs are grouped
1578
// by their common block height. This method can be used to quickly share with a
1579
// peer the set of channels we know of within a particular range to catch them
1580
// up after a period of time offline. If withTimestamps is true then the
1581
// timestamp info of the latest received channel update messages for the channel
1582
// will be included in the response.
1583
//
1584
// NOTE: This is part of the V1Store interface.
1585
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1586
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1587

×
1588
        var (
×
1589
                ctx       = context.TODO()
×
1590
                startSCID = &lnwire.ShortChannelID{
×
1591
                        BlockHeight: startHeight,
×
1592
                }
×
1593
                endSCID = lnwire.ShortChannelID{
×
1594
                        BlockHeight: endHeight,
×
1595
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1596
                        TxPosition:  math.MaxUint16,
×
1597
                }
×
1598
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1599
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1600
        )
×
1601

×
1602
        // 1) get all channels where channelID is between start and end chan ID.
×
1603
        // 2) skip if not public (i.e., no channel_proof).
×
1604
        // 3) collect that channel.
×
1605
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1606
        //    and add those timestamps to the collected channel.
×
1607
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1608
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1609
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1610
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1611
                                StartScid: chanIDStart,
×
1612
                                EndScid:   chanIDEnd,
×
1613
                        },
×
1614
                )
×
1615
                if err != nil {
×
1616
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1617
                                err)
×
1618
                }
×
1619

1620
                for _, dbChan := range dbChans {
×
1621
                        cid := lnwire.NewShortChanIDFromInt(
×
1622
                                byteOrder.Uint64(dbChan.Scid),
×
1623
                        )
×
1624
                        chanInfo := NewChannelUpdateInfo(
×
1625
                                cid, time.Time{}, time.Time{},
×
1626
                        )
×
1627

×
1628
                        if !withTimestamps {
×
1629
                                channelsPerBlock[cid.BlockHeight] = append(
×
1630
                                        channelsPerBlock[cid.BlockHeight],
×
1631
                                        chanInfo,
×
1632
                                )
×
1633

×
1634
                                continue
×
1635
                        }
1636

1637
                        //nolint:ll
1638
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1639
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1640
                                        Version:   int16(lnwire.GossipVersion1),
×
1641
                                        ChannelID: dbChan.ID,
×
1642
                                        NodeID:    dbChan.NodeID1,
×
1643
                                },
×
1644
                        )
×
1645
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1646
                                return fmt.Errorf("unable to fetch node1 "+
×
1647
                                        "policy: %w", err)
×
1648
                        } else if err == nil {
×
1649
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1650
                                        node1Policy.LastUpdate.Int64, 0,
×
1651
                                )
×
1652
                        }
×
1653

1654
                        //nolint:ll
1655
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1656
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1657
                                        Version:   int16(lnwire.GossipVersion1),
×
1658
                                        ChannelID: dbChan.ID,
×
1659
                                        NodeID:    dbChan.NodeID2,
×
1660
                                },
×
1661
                        )
×
1662
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1663
                                return fmt.Errorf("unable to fetch node2 "+
×
1664
                                        "policy: %w", err)
×
1665
                        } else if err == nil {
×
1666
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1667
                                        node2Policy.LastUpdate.Int64, 0,
×
1668
                                )
×
1669
                        }
×
1670

1671
                        channelsPerBlock[cid.BlockHeight] = append(
×
1672
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1673
                        )
×
1674
                }
1675

1676
                return nil
×
1677
        }, func() {
×
1678
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1679
        })
×
1680
        if err != nil {
×
1681
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1682
        }
×
1683

1684
        if len(channelsPerBlock) == 0 {
×
1685
                return nil, nil
×
1686
        }
×
1687

1688
        // Return the channel ranges in ascending block height order.
1689
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1690
        slices.Sort(blocks)
×
1691

×
1692
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1693
                return BlockChannelRange{
×
1694
                        Height:   block,
×
1695
                        Channels: channelsPerBlock[block],
×
1696
                }
×
1697
        }), nil
×
1698
}
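
// exampleFilterChannelRange is an illustrative usage sketch (not part of the
// original file) that logs how many channels were found per block over a
// height range, without requesting the optional policy timestamps.
func exampleFilterChannelRange(s *SQLStore, startHeight,
        endHeight uint32) error {

        ranges, err := s.FilterChannelRange(startHeight, endHeight, false)
        if err != nil {
                return err
        }

        // The returned slice is sorted by ascending block height.
        for _, blockRange := range ranges {
                log.Debugf("height %d: %d channels", blockRange.Height,
                        len(blockRange.Channels))
        }

        return nil
}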
1699

1700
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1701
// zombie. This method is used on an ad-hoc basis, when channels need to be
1702
// marked as zombies outside the normal pruning cycle.
1703
//
1704
// NOTE: part of the V1Store interface.
1705
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1706
        pubKey1, pubKey2 [33]byte) error {
×
1707

×
1708
        ctx := context.TODO()
×
1709

×
1710
        s.cacheMu.Lock()
×
1711
        defer s.cacheMu.Unlock()
×
1712

×
1713
        chanIDB := channelIDToBytes(chanID)
×
1714

×
1715
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1716
                return db.UpsertZombieChannel(
×
1717
                        ctx, sqlc.UpsertZombieChannelParams{
×
1718
                                Version:  int16(lnwire.GossipVersion1),
×
1719
                                Scid:     chanIDB,
×
1720
                                NodeKey1: pubKey1[:],
×
1721
                                NodeKey2: pubKey2[:],
×
1722
                        },
×
1723
                )
×
1724
        }, sqldb.NoOpReset)
×
1725
        if err != nil {
×
1726
                return fmt.Errorf("unable to upsert zombie channel "+
×
1727
                        "(channel_id=%d): %w", chanID, err)
×
1728
        }
×
1729

1730
        s.rejectCache.remove(chanID)
×
1731
        s.chanCache.remove(chanID)
×
1732

×
1733
        return nil
×
1734
}
1735

1736
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1737
//
1738
// NOTE: part of the V1Store interface.
1739
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1740
        s.cacheMu.Lock()
×
1741
        defer s.cacheMu.Unlock()
×
1742

×
1743
        var (
×
1744
                ctx     = context.TODO()
×
1745
                chanIDB = channelIDToBytes(chanID)
×
1746
        )
×
1747

×
1748
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1749
                res, err := db.DeleteZombieChannel(
×
1750
                        ctx, sqlc.DeleteZombieChannelParams{
×
1751
                                Scid:    chanIDB,
×
1752
                                Version: int16(lnwire.GossipVersion1),
×
1753
                        },
×
1754
                )
×
1755
                if err != nil {
×
1756
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1757
                                err)
×
1758
                }
×
1759

1760
                rows, err := res.RowsAffected()
×
1761
                if err != nil {
×
1762
                        return err
×
1763
                }
×
1764

1765
                if rows == 0 {
×
1766
                        return ErrZombieEdgeNotFound
×
1767
                } else if rows > 1 {
×
1768
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1769
                                "expected 1", rows)
×
1770
                }
×
1771

1772
                return nil
×
1773
        }, sqldb.NoOpReset)
1774
        if err != nil {
×
1775
                return fmt.Errorf("unable to mark edge live "+
×
1776
                        "(channel_id=%d): %w", chanID, err)
×
1777
        }
×
1778

1779
        s.rejectCache.remove(chanID)
×
1780
        s.chanCache.remove(chanID)
×
1781

×
1782
        return err
×
1783
}
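
// exampleResurrectEdge is an illustrative usage sketch (not part of the
// original file) showing how a caller might clear a zombie entry once a fresh
// update arrives. ErrZombieEdgeNotFound simply means there was nothing to
// clear, so it is treated as a no-op here.
func exampleResurrectEdge(s *SQLStore, chanID uint64) error {
        err := s.MarkEdgeLive(chanID)
        if errors.Is(err, ErrZombieEdgeNotFound) {
                return nil
        }

        return err
}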
1784

1785
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1786
// zombie, then the two node public keys corresponding to this edge are also
1787
// returned.
1788
//
1789
// NOTE: part of the V1Store interface.
1790
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1791
        error) {
×
1792

×
1793
        var (
×
1794
                ctx              = context.TODO()
×
1795
                isZombie         bool
×
1796
                pubKey1, pubKey2 route.Vertex
×
1797
                chanIDB          = channelIDToBytes(chanID)
×
1798
        )
×
1799

×
1800
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1801
                zombie, err := db.GetZombieChannel(
×
1802
                        ctx, sqlc.GetZombieChannelParams{
×
1803
                                Scid:    chanIDB,
×
1804
                                Version: int16(lnwire.GossipVersion1),
×
1805
                        },
×
1806
                )
×
1807
                if errors.Is(err, sql.ErrNoRows) {
×
1808
                        return nil
×
1809
                }
×
1810
                if err != nil {
×
1811
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1812
                                err)
×
1813
                }
×
1814

1815
                copy(pubKey1[:], zombie.NodeKey1)
×
1816
                copy(pubKey2[:], zombie.NodeKey2)
×
1817
                isZombie = true
×
1818

×
1819
                return nil
×
1820
        }, sqldb.NoOpReset)
1821
        if err != nil {
×
1822
                return false, route.Vertex{}, route.Vertex{},
×
1823
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1824
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1825
        }
×
1826

1827
        return isZombie, pubKey1, pubKey2, nil
×
1828
}
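
// exampleIsZombieEdge is an illustrative usage sketch (not part of the
// original file) that checks the zombie index and logs the node keys recorded
// for the channel if it is indeed a zombie.
func exampleIsZombieEdge(s *SQLStore, chanID uint64) error {
        isZombie, pub1, pub2, err := s.IsZombieEdge(chanID)
        if err != nil {
                return err
        }

        if isZombie {
                log.Debugf("channel %d is a zombie (node1=%x, node2=%x)",
                        chanID, pub1[:], pub2[:])
        }

        return nil
}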
1829

1830
// NumZombies returns the current number of zombie channels in the graph.
1831
//
1832
// NOTE: part of the V1Store interface.
1833
func (s *SQLStore) NumZombies() (uint64, error) {
×
1834
        var (
×
1835
                ctx        = context.TODO()
×
1836
                numZombies uint64
×
1837
        )
×
1838
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1839
                count, err := db.CountZombieChannels(
×
1840
                        ctx, int16(lnwire.GossipVersion1),
×
1841
                )
×
1842
                if err != nil {
×
1843
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1844
                                err)
×
1845
                }
×
1846

1847
                numZombies = uint64(count)
×
1848

×
1849
                return nil
×
1850
        }, sqldb.NoOpReset)
1851
        if err != nil {
×
1852
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1853
        }
×
1854

1855
        return numZombies, nil
×
1856
}
1857

1858
// DeleteChannelEdges removes edges with the given channel IDs from the
1859
// database and marks them as zombies. This ensures that we're unable to re-add
1860
// them to our database. If an edge does not exist within the
1861
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1862
// true, then when we mark these edges as zombies, we'll set up the keys such
1863
// that we require the node that failed to send the fresh update to be the one
1864
// that resurrects the channel from its zombie state. The markZombie bool
1865
// denotes whether to mark the channel as a zombie.
1866
//
1867
// NOTE: part of the V1Store interface.
1868
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1869
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1870

×
1871
        s.cacheMu.Lock()
×
1872
        defer s.cacheMu.Unlock()
×
1873

×
1874
        // Keep track of which channels we end up finding so that we can
×
1875
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1876
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1877
        for _, chanID := range chanIDs {
×
1878
                chanLookup[chanID] = struct{}{}
×
1879
        }
×
1880

1881
        var (
×
1882
                ctx   = context.TODO()
×
1883
                edges []*models.ChannelEdgeInfo
×
1884
        )
×
1885
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1886
                // First, collect all channel rows.
×
1887
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1888
                chanCallBack := func(ctx context.Context,
×
1889
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1890

×
1891
                        // Deleting the entry from the map indicates that we
×
1892
                        // have found the channel.
×
1893
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1894
                        delete(chanLookup, scid)
×
1895

×
1896
                        channelRows = append(channelRows, row)
×
1897

×
1898
                        return nil
×
1899
                }
×
1900

1901
                err := s.forEachChanWithPoliciesInSCIDList(
×
1902
                        ctx, db, chanCallBack, chanIDs,
×
1903
                )
×
1904
                if err != nil {
×
1905
                        return err
×
1906
                }
×
1907

1908
                if len(chanLookup) > 0 {
×
1909
                        return ErrEdgeNotFound
×
1910
                }
×
1911

1912
                if len(channelRows) == 0 {
×
1913
                        return nil
×
1914
                }
×
1915

1916
                // Batch build all channel edges.
1917
                var chanIDsToDelete []int64
×
1918
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1919
                        ctx, s.cfg, db, channelRows,
×
1920
                )
×
1921
                if err != nil {
×
1922
                        return err
×
1923
                }
×
1924

1925
                if markZombie {
×
1926
                        for i, row := range channelRows {
×
1927
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1928

×
1929
                                err := handleZombieMarking(
×
1930
                                        ctx, db, row, edges[i],
×
1931
                                        strictZombiePruning, scid,
×
1932
                                )
×
1933
                                if err != nil {
×
1934
                                        return fmt.Errorf("unable to mark "+
×
1935
                                                "channel as zombie: %w", err)
×
1936
                                }
×
1937
                        }
1938
                }
1939

1940
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1941
        }, func() {
×
1942
                edges = nil
×
1943

×
1944
                // Re-fill the lookup map.
×
1945
                for _, chanID := range chanIDs {
×
1946
                        chanLookup[chanID] = struct{}{}
×
1947
                }
×
1948
        })
1949
        if err != nil {
×
1950
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1951
                        err)
×
1952
        }
×
1953

1954
        for _, chanID := range chanIDs {
×
1955
                s.rejectCache.remove(chanID)
×
1956
                s.chanCache.remove(chanID)
×
1957
        }
×
1958

1959
        return edges, nil
×
1960
}
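
// exampleDeleteChannelEdges is an illustrative usage sketch (not part of the
// original file) that removes a batch of closed channels, marking them as
// zombies with strict pruning, and logs the edges that were removed.
func exampleDeleteChannelEdges(s *SQLStore, closedSCIDs []uint64) error {
        edges, err := s.DeleteChannelEdges(true, true, closedSCIDs...)
        if err != nil {
                return err
        }

        for _, edge := range edges {
                log.Debugf("removed channel %d", edge.ChannelID)
        }

        return nil
}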
1961

1962
// FetchChannelEdgesByID attempts to look up the two directed edges for the
1963
// channel identified by the channel ID. If the channel can't be found, then
1964
// ErrEdgeNotFound is returned. A struct which houses the general information
1965
// for the channel itself is returned as well as two structs that contain the
1966
// routing policies for the channel in either direction.
1967
//
1968
// ErrZombieEdge can be returned if the edge is currently marked as a zombie
1969
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1970
// the ChannelEdgeInfo will only include the public keys of each node.
1971
//
1972
// NOTE: part of the V1Store interface.
1973
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1974
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1975
        *models.ChannelEdgePolicy, error) {
×
1976

×
1977
        var (
×
1978
                ctx              = context.TODO()
×
1979
                edge             *models.ChannelEdgeInfo
×
1980
                policy1, policy2 *models.ChannelEdgePolicy
×
1981
                chanIDB          = channelIDToBytes(chanID)
×
1982
        )
×
1983
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1984
                row, err := db.GetChannelBySCIDWithPolicies(
×
1985
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1986
                                Scid:    chanIDB,
×
1987
                                Version: int16(lnwire.GossipVersion1),
×
1988
                        },
×
1989
                )
×
1990
                if errors.Is(err, sql.ErrNoRows) {
×
1991
                        // First check if this edge is perhaps in the zombie
×
1992
                        // index.
×
1993
                        zombie, err := db.GetZombieChannel(
×
1994
                                ctx, sqlc.GetZombieChannelParams{
×
1995
                                        Scid:    chanIDB,
×
1996
                                        Version: int16(lnwire.GossipVersion1),
×
1997
                                },
×
1998
                        )
×
1999
                        if errors.Is(err, sql.ErrNoRows) {
×
2000
                                return ErrEdgeNotFound
×
2001
                        } else if err != nil {
×
2002
                                return fmt.Errorf("unable to check if "+
×
2003
                                        "channel is zombie: %w", err)
×
2004
                        }
×
2005

2006
                        // At this point, we know the channel is a zombie, so
2007
                        // we'll return an error indicating this, and we will
2008
                        // populate the edge info with the public keys of each
2009
                        // party as this is the only information we have about
2010
                        // it.
2011
                        edge = &models.ChannelEdgeInfo{}
×
2012
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
2013
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
2014

×
2015
                        return ErrZombieEdge
×
2016
                } else if err != nil {
×
2017
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2018
                }
×
2019

2020
                node1, node2, err := buildNodeVertices(
×
2021
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2022
                )
×
2023
                if err != nil {
×
2024
                        return err
×
2025
                }
×
2026

2027
                edge, err = getAndBuildEdgeInfo(
×
2028
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2029
                )
×
2030
                if err != nil {
×
2031
                        return fmt.Errorf("unable to build channel info: %w",
×
2032
                                err)
×
2033
                }
×
2034

2035
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2036
                if err != nil {
×
2037
                        return fmt.Errorf("unable to extract channel "+
×
2038
                                "policies: %w", err)
×
2039
                }
×
2040

2041
                policy1, policy2, err = getAndBuildChanPolicies(
×
2042
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2043
                        node1, node2,
×
2044
                )
×
2045
                if err != nil {
×
2046
                        return fmt.Errorf("unable to build channel "+
×
2047
                                "policies: %w", err)
×
2048
                }
×
2049

2050
                return nil
×
2051
        }, sqldb.NoOpReset)
2052
        if err != nil {
×
2053
                // If we are returning the ErrZombieEdge, then we also need to
×
2054
                // return the edge info as the method comment indicates that
×
2055
                // this will be populated when the edge is a zombie.
×
2056
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2057
                        err)
×
2058
        }
×
2059

2060
        return edge, policy1, policy2, nil
×
2061
}
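
// exampleFetchChannelEdgesByID is an illustrative usage sketch (not part of
// the original file) showing how to distinguish the zombie case: the returned
// edge info is then non-nil, but only its node key bytes are populated and
// both policies are nil.
func exampleFetchChannelEdgesByID(s *SQLStore, chanID uint64) error {
        info, pol1, pol2, err := s.FetchChannelEdgesByID(chanID)
        switch {
        case errors.Is(err, ErrZombieEdge):
                log.Debugf("channel %d is a zombie between %x and %x", chanID,
                        info.NodeKey1Bytes[:], info.NodeKey2Bytes[:])

                return nil

        case err != nil:
                return err
        }

        log.Debugf("channel %d advertised: node1=%v, node2=%v", chanID,
                pol1 != nil, pol2 != nil)

        return nil
}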
2062

2063
// FetchChannelEdgesByOutpoint attempts to look up the two directed edges for
2064
// the channel identified by the funding outpoint. If the channel can't be
2065
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2066
// information for the channel itself is returned as well as two structs that
2067
// contain the routing policies for the channel in either direction.
2068
//
2069
// NOTE: part of the V1Store interface.
2070
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2071
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2072
        *models.ChannelEdgePolicy, error) {
×
2073

×
2074
        var (
×
2075
                ctx              = context.TODO()
×
2076
                edge             *models.ChannelEdgeInfo
×
2077
                policy1, policy2 *models.ChannelEdgePolicy
×
2078
        )
×
2079
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2080
                row, err := db.GetChannelByOutpointWithPolicies(
×
2081
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2082
                                Outpoint: op.String(),
×
2083
                                Version:  int16(lnwire.GossipVersion1),
×
2084
                        },
×
2085
                )
×
2086
                if errors.Is(err, sql.ErrNoRows) {
×
2087
                        return ErrEdgeNotFound
×
2088
                } else if err != nil {
×
2089
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2090
                }
×
2091

2092
                node1, node2, err := buildNodeVertices(
×
2093
                        row.Node1Pubkey, row.Node2Pubkey,
×
2094
                )
×
2095
                if err != nil {
×
2096
                        return err
×
2097
                }
×
2098

2099
                edge, err = getAndBuildEdgeInfo(
×
2100
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2101
                )
×
2102
                if err != nil {
×
2103
                        return fmt.Errorf("unable to build channel info: %w",
×
2104
                                err)
×
2105
                }
×
2106

2107
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2108
                if err != nil {
×
2109
                        return fmt.Errorf("unable to extract channel "+
×
2110
                                "policies: %w", err)
×
2111
                }
×
2112

2113
                policy1, policy2, err = getAndBuildChanPolicies(
×
2114
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2115
                        node1, node2,
×
2116
                )
×
2117
                if err != nil {
×
2118
                        return fmt.Errorf("unable to build channel "+
×
2119
                                "policies: %w", err)
×
2120
                }
×
2121

2122
                return nil
×
2123
        }, sqldb.NoOpReset)
2124
        if err != nil {
×
2125
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2126
                        err)
×
2127
        }
×
2128

2129
        return edge, policy1, policy2, nil
×
2130
}
2131

2132
// HasChannelEdge returns true if the database knows of a channel edge with the
2133
// passed channel ID, and false otherwise. If an edge with that ID is found
2134
// within the graph, then two time stamps representing the last time the edge
2135
// was updated for both directed edges are returned along with the boolean. If
2136
// it is not found, then the zombie index is checked and its result is returned
2137
// as the second boolean.
2138
//
2139
// NOTE: part of the V1Store interface.
2140
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2141
        bool, error) {
×
2142

×
2143
        ctx := context.TODO()
×
2144

×
2145
        var (
×
2146
                exists          bool
×
2147
                isZombie        bool
×
2148
                node1LastUpdate time.Time
×
2149
                node2LastUpdate time.Time
×
2150
        )
×
2151

×
2152
        // We'll query the cache with the shared lock held to allow multiple
×
2153
        // readers to access values in the cache concurrently if they exist.
×
2154
        s.cacheMu.RLock()
×
2155
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2156
                s.cacheMu.RUnlock()
×
2157
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2158
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2159
                exists, isZombie = entry.flags.unpack()
×
2160

×
2161
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2162
        }
×
2163
        s.cacheMu.RUnlock()
×
2164

×
2165
        s.cacheMu.Lock()
×
2166
        defer s.cacheMu.Unlock()
×
2167

×
2168
        // The item was not found with the shared lock, so we'll acquire the
×
2169
        // exclusive lock and check the cache again in case another method added
×
2170
        // the entry to the cache while no lock was held.
×
2171
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2172
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2173
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2174
                exists, isZombie = entry.flags.unpack()
×
2175

×
2176
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2177
        }
×
2178

2179
        chanIDB := channelIDToBytes(chanID)
×
2180
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2181
                channel, err := db.GetChannelBySCID(
×
2182
                        ctx, sqlc.GetChannelBySCIDParams{
×
2183
                                Scid:    chanIDB,
×
2184
                                Version: int16(lnwire.GossipVersion1),
×
2185
                        },
×
2186
                )
×
2187
                if errors.Is(err, sql.ErrNoRows) {
×
2188
                        // Check if it is a zombie channel.
×
2189
                        isZombie, err = db.IsZombieChannel(
×
2190
                                ctx, sqlc.IsZombieChannelParams{
×
2191
                                        Scid:    chanIDB,
×
2192
                                        Version: int16(lnwire.GossipVersion1),
×
2193
                                },
×
2194
                        )
×
2195
                        if err != nil {
×
2196
                                return fmt.Errorf("could not check if channel "+
×
2197
                                        "is zombie: %w", err)
×
2198
                        }
×
2199

2200
                        return nil
×
2201
                } else if err != nil {
×
2202
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2203
                }
×
2204

2205
                exists = true
×
2206

×
2207
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2208
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2209
                                Version:   int16(lnwire.GossipVersion1),
×
2210
                                ChannelID: channel.ID,
×
2211
                                NodeID:    channel.NodeID1,
×
2212
                        },
×
2213
                )
×
2214
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2215
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2216
                                err)
×
2217
                } else if err == nil {
×
2218
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2219
                }
×
2220

2221
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2222
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2223
                                Version:   int16(lnwire.GossipVersion1),
×
2224
                                ChannelID: channel.ID,
×
2225
                                NodeID:    channel.NodeID2,
×
2226
                        },
×
2227
                )
×
2228
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2229
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2230
                                err)
×
2231
                } else if err == nil {
×
2232
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2233
                }
×
2234

2235
                return nil
×
2236
        }, sqldb.NoOpReset)
2237
        if err != nil {
×
2238
                return time.Time{}, time.Time{}, false, false,
×
2239
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2240
        }
×
2241

2242
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2243
                upd1Time: node1LastUpdate.Unix(),
×
2244
                upd2Time: node2LastUpdate.Unix(),
×
2245
                flags:    packRejectFlags(exists, isZombie),
×
2246
        })
×
2247

×
2248
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2249
}
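
// exampleHasChannelEdge is an illustrative usage sketch (not part of the
// original file) that uses HasChannelEdge to decide whether a received
// update with the given timestamp is stale for node 1's direction. The
// staleness rule shown here is only an example.
func exampleHasChannelEdge(s *SQLStore, chanID uint64,
        updateTime time.Time) (bool, error) {

        node1Update, _, exists, isZombie, err := s.HasChannelEdge(chanID)
        if err != nil {
                return false, err
        }

        // Treat the update as stale if the channel is a known zombie, or if
        // we already have an update for this direction that is at least as
        // recent.
        stale := isZombie || (exists && !node1Update.Before(updateTime))

        return stale, nil
}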
2250

2251
// ChannelID attempts to look up the 8-byte compact channel ID that maps to the
2252
// passed channel point (outpoint). If the passed channel doesn't exist within
2253
// the database, then ErrEdgeNotFound is returned.
2254
//
2255
// NOTE: part of the V1Store interface.
2256
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2257
        var (
×
2258
                ctx       = context.TODO()
×
2259
                channelID uint64
×
2260
        )
×
2261
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2262
                chanID, err := db.GetSCIDByOutpoint(
×
2263
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2264
                                Outpoint: chanPoint.String(),
×
2265
                                Version:  int16(lnwire.GossipVersion1),
×
2266
                        },
×
2267
                )
×
2268
                if errors.Is(err, sql.ErrNoRows) {
×
2269
                        return ErrEdgeNotFound
×
2270
                } else if err != nil {
×
2271
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2272
                                err)
×
2273
                }
×
2274

2275
                channelID = byteOrder.Uint64(chanID)
×
2276

×
2277
                return nil
×
2278
        }, sqldb.NoOpReset)
2279
        if err != nil {
×
2280
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2281
        }
×
2282

2283
        return channelID, nil
×
2284
}
2285

2286
// IsPublicNode is a helper method that determines whether the node with the
2287
// given public key is seen as a public node in the graph from the graph's
2288
// source node's point of view.
2289
//
2290
// NOTE: part of the V1Store interface.
2291
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2292
        ctx := context.TODO()
×
2293

×
2294
        var isPublic bool
×
2295
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2296
                var err error
×
2297
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2298

×
2299
                return err
×
2300
        }, sqldb.NoOpReset)
×
2301
        if err != nil {
×
2302
                return false, fmt.Errorf("unable to check if node is "+
×
2303
                        "public: %w", err)
×
2304
        }
×
2305

2306
        return isPublic, nil
×
2307
}
2308

2309
// FetchChanInfos returns the set of channel edges that correspond to the passed
2310
// channel ID's. If an edge in the query is unknown to the database, it will be
2311
// skipped and the result will contain only those edges that exist at the time
2312
// of the query. This can be used to respond to peer queries that are seeking to
2313
// fill in gaps in their view of the channel graph.
2314
//
2315
// NOTE: part of the V1Store interface.
2316
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2317
        var (
×
2318
                ctx   = context.TODO()
×
2319
                edges = make(map[uint64]ChannelEdge)
×
2320
        )
×
2321
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2322
                // First, collect all channel rows.
×
2323
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2324
                chanCallBack := func(ctx context.Context,
×
2325
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2326

×
2327
                        channelRows = append(channelRows, row)
×
2328
                        return nil
×
2329
                }
×
2330

2331
                err := s.forEachChanWithPoliciesInSCIDList(
×
2332
                        ctx, db, chanCallBack, chanIDs,
×
2333
                )
×
2334
                if err != nil {
×
2335
                        return err
×
2336
                }
×
2337

2338
                if len(channelRows) == 0 {
×
2339
                        return nil
×
2340
                }
×
2341

2342
                // Batch build all channel edges.
2343
                chans, err := batchBuildChannelEdges(
×
2344
                        ctx, s.cfg, db, channelRows,
×
2345
                )
×
2346
                if err != nil {
×
2347
                        return fmt.Errorf("unable to build channel edges: %w",
×
2348
                                err)
×
2349
                }
×
2350

2351
                for _, c := range chans {
×
2352
                        edges[c.Info.ChannelID] = c
×
2353
                }
×
2354

2355
                return err
×
2356
        }, func() {
×
2357
                clear(edges)
×
2358
        })
×
2359
        if err != nil {
×
2360
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2361
        }
×
2362

2363
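        // Return the edges in the same order as the requested channel IDs,
        // skipping any IDs that were not found in the database.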
        res := make([]ChannelEdge, 0, len(edges))
        for _, chanID := range chanIDs {
                edge, ok := edges[chanID]
                if !ok {
                        continue
                }

                res = append(res, edge)
        }

        return res, nil
}

// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
        db SQLQueries, cb func(ctx context.Context,
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
        chanIDs []uint64) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
                error) {

                return db.GetChannelsBySCIDWithPolicies(
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
                cb,
        )
}

// FilterKnownChanIDs takes a set of channel IDs and returns the subset of
// channel IDs that we don't know and are not known zombies of the passed set.
// In other words, we perform a set difference of our set of chan ID's and the
// ones passed in. This method can be used by callers to determine the set of
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
// known zombies are also returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
        []ChannelUpdateInfo, error) {

        var (
                ctx          = context.TODO()
                newChanIDs   []uint64
                knownZombies []ChannelUpdateInfo
                infoLookup   = make(
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
                )
        )

        // We first build a lookup map of the channel ID's to the
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
        // already know about.
        for _, chanInfo := range chansInfo {
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
        }

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                // The call-back function deletes known channels from
                // infoLookup, so that we can later check which channels are
                // zombies by only looking at the remaining channels in the set.
                cb := func(ctx context.Context,
                        channel sqlc.GraphChannel) error {

                        delete(infoLookup, byteOrder.Uint64(channel.Scid))

                        return nil
                }

                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
                if err != nil {
                        return fmt.Errorf("unable to iterate through "+
                                "channels: %w", err)
                }

                // We want to ensure that we deal with the channels in the
                // same order that they were passed in, so we iterate over the
                // original chansInfo slice and then check if that channel is
                // still in the infoLookup map.
                for _, chanInfo := range chansInfo {
                        channelID := chanInfo.ShortChannelID.ToUint64()
                        if _, ok := infoLookup[channelID]; !ok {
                                continue
                        }

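                        // The channel is unknown to our graph, so check
                        // whether we have previously marked it as a zombie.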
                        isZombie, err := db.IsZombieChannel(
                                ctx, sqlc.IsZombieChannelParams{
                                        Scid:    channelIDToBytes(channelID),
                                        Version: int16(lnwire.GossipVersion1),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to fetch zombie "+
                                        "channel: %w", err)
                        }

                        if isZombie {
                                knownZombies = append(knownZombies, chanInfo)

                                continue
                        }

                        newChanIDs = append(newChanIDs, channelID)
                }

                return nil
        }, func() {
                newChanIDs = nil
                knownZombies = nil
                // Rebuild the infoLookup map in case of a rollback.
                for _, chanInfo := range chansInfo {
                        scid := chanInfo.ShortChannelID.ToUint64()
                        infoLookup[scid] = chanInfo
                }
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
        }

        return newChanIDs, knownZombies, nil
}

// forEachChanInSCIDList is a helper method that executes a paged query
// against the database to fetch all channels that match the passed
// ChannelUpdateInfo slice. The callback function is called for each channel
// that is found.
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
        chansInfo []ChannelUpdateInfo) error {

        queryWrapper := func(ctx context.Context,
                scids [][]byte) ([]sqlc.GraphChannel, error) {

                return db.GetChannelsBySCIDs(
                        ctx, sqlc.GetChannelsBySCIDsParams{
                                Version: int16(lnwire.GossipVersion1),
                                Scids:   scids,
                        },
                )
        }

        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
                channelID := chanInfo.ShortChannelID.ToUint64()

                return channelIDToBytes(channelID)
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
                cb,
        )
}

// PruneGraphNodes is a garbage collection method which attempts to prune out
// any nodes from the channel graph that are currently unconnected. This
// ensures that we only maintain a graph of reachable nodes. In the event that
// a pruned node gains more channels, it will be re-added to the graph.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
        var ctx = context.TODO()

        var prunedNodes []route.Vertex
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                var err error
                prunedNodes, err = s.pruneGraphNodes(ctx, db)

                return err
        }, func() {
                prunedNodes = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
        }

        return prunedNodes, nil
}

// PruneGraph prunes newly closed channels from the channel graph in response
// to a new block being solved on the network. Any transactions which spend the
// funding output of any known channels within the graph will be deleted.
// Additionally, the "prune tip", or the last block which has been used to
// prune the graph is stored so callers can ensure the graph is fully in sync
// with the current UTXO state. A slice of channels that have been closed by
// the target block along with any pruned nodes are returned if the function
// succeeds without error.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
        blockHash *chainhash.Hash, blockHeight uint32) (
        []*models.ChannelEdgeInfo, []route.Vertex, error) {

        ctx := context.TODO()

        s.cacheMu.Lock()
        defer s.cacheMu.Unlock()

        var (
                closedChans []*models.ChannelEdgeInfo
                prunedNodes []route.Vertex
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                // First, collect all channel rows that need to be pruned.
                var channelRows []sqlc.GetChannelsByOutpointsRow
                channelCallback := func(ctx context.Context,
                        row sqlc.GetChannelsByOutpointsRow) error {

                        channelRows = append(channelRows, row)

                        return nil
                }

                err := s.forEachChanInOutpoints(
                        ctx, db, spentOutputs, channelCallback,
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels by "+
                                "outpoints: %w", err)
                }

                if len(channelRows) == 0 {
                        // There are no channels to prune. So we can exit early
                        // after updating the prune log.
                        err = db.UpsertPruneLogEntry(
                                ctx, sqlc.UpsertPruneLogEntryParams{
                                        BlockHash:   blockHash[:],
                                        BlockHeight: int64(blockHeight),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert prune log "+
                                        "entry: %w", err)
                        }

                        return nil
                }

                // Batch build all channel edges for pruning.
                var chansToDelete []int64
                closedChans, chansToDelete, err = batchBuildChannelInfo(
                        ctx, s.cfg, db, channelRows,
                )
                if err != nil {
                        return err
                }

                err = s.deleteChannels(ctx, db, chansToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                err = db.UpsertPruneLogEntry(
                        ctx, sqlc.UpsertPruneLogEntryParams{
                                BlockHash:   blockHash[:],
                                BlockHeight: int64(blockHeight),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert prune log "+
                                "entry: %w", err)
                }

                // Now that we've pruned some channels, we'll also prune any
                // nodes that no longer have any channels.
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
                if err != nil {
                        return fmt.Errorf("unable to prune graph nodes: %w",
                                err)
                }

                return nil
        }, func() {
                prunedNodes = nil
                closedChans = nil
        })
        if err != nil {
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
        }

        for _, channel := range closedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return closedChans, prunedNodes, nil
}

// forEachChanInOutpoints is a helper function that executes a paginated
// query to fetch channels by their outpoints and applies the given call-back
// to each.
//
// NOTE: this fetches channels for all protocol versions.
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
                row sqlc.GetChannelsByOutpointsRow) error) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context,
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
                error) {

                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
        }

        // Define the conversion function from Outpoint to string.
        outpointToString := func(outpoint *wire.OutPoint) string {
                return outpoint.String()
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
                queryWrapper, cb,
        )
}

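// deleteChannels deletes the channels with the given database IDs, batching
// the IDs according to the store's query configuration.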
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
        dbIDs []int64) error {

        // Create a wrapper that uses the transaction's db instance to execute
        // the query.
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
                return nil, db.DeleteChannels(ctx, ids)
        }

        idConverter := func(id int64) int64 {
                return id
        }

        return sqldb.ExecuteBatchQuery(
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
                queryWrapper, func(ctx context.Context, _ any) error {
                        return nil
                },
        )
}

// ChannelView returns the verifiable edge information for each active channel
// within the known channel graph. The set of UTXOs (along with their scripts)
// returned are the ones that need to be watched on chain to detect channel
// closes on the resident blockchain.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
        var (
                ctx        = context.TODO()
                edgePoints []EdgePoint
        )

        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                handleChannel := func(_ context.Context,
                        channel sqlc.ListChannelsPaginatedRow) error {

                        pkScript, err := genMultiSigP2WSH(
                                channel.BitcoinKey1, channel.BitcoinKey2,
                        )
                        if err != nil {
                                return err
                        }

                        op, err := wire.NewOutPointFromString(channel.Outpoint)
                        if err != nil {
                                return err
                        }

                        edgePoints = append(edgePoints, EdgePoint{
                                FundingPkScript: pkScript,
                                OutPoint:        *op,
                        })

                        return nil
                }

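                // Page through all V1 channels, using the channel row ID as
                // the pagination cursor.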
                queryFunc := func(ctx context.Context, lastID int64,
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {

                        return db.ListChannelsPaginated(
                                ctx, sqlc.ListChannelsPaginatedParams{
                                        Version: int16(lnwire.GossipVersion1),
                                        ID:      lastID,
                                        Limit:   limit,
                                },
                        )
                }

                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
                        return row.ID
                }

                return sqldb.ExecutePaginatedQuery(
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
                        extractCursor, handleChannel,
                )
        }, func() {
                edgePoints = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
        }

        return edgePoints, nil
}

// PruneTip returns the block height and hash of the latest block that has been
// used to prune channels in the graph. Knowing the "prune tip" allows callers
// to tell if the graph is currently in sync with the current best known UTXO
// state.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
        var (
                ctx       = context.TODO()
                tipHash   chainhash.Hash
                tipHeight uint32
        )
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                pruneTip, err := db.GetPruneTip(ctx)
                if errors.Is(err, sql.ErrNoRows) {
                        return ErrGraphNeverPruned
                } else if err != nil {
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
                }

                tipHash = chainhash.Hash(pruneTip.BlockHash)
                tipHeight = uint32(pruneTip.BlockHeight)

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return nil, 0, err
        }

        return &tipHash, tipHeight, nil
}

// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
//
// NOTE: this prunes nodes across protocol versions. It will never prune the
// source nodes.
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
        db SQLQueries) ([]route.Vertex, error) {

        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
        if err != nil {
                return nil, fmt.Errorf("unable to delete unconnected "+
                        "nodes: %w", err)
        }

        prunedNodes := make([]route.Vertex, len(nodeKeys))
        for i, nodeKey := range nodeKeys {
                pub, err := route.NewVertexFromBytes(nodeKey)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse pubkey "+
                                "from bytes: %w", err)
                }

                prunedNodes[i] = pub
        }

        return prunedNodes, nil
}

// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
        []*models.ChannelEdgeInfo, error) {

        ctx := context.TODO()

        var (
                // Every channel having a ShortChannelID starting at 'height'
                // will no longer be confirmed.
                startShortChanID = lnwire.ShortChannelID{
                        BlockHeight: height,
                }

                // Delete everything after this height from the db up until the
                // SCID alias range.
                endShortChanID = aliasmgr.StartingAlias

                removedChans []*models.ChannelEdgeInfo

                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                rows, err := db.GetChannelsBySCIDRange(
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
                                StartScid: chanIDStart,
                                EndScid:   chanIDEnd,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to fetch channels: %w", err)
                }

                if len(rows) == 0 {
                        // No channels to disconnect, but still clean up prune
                        // log.
                        return db.DeletePruneLogEntriesInRange(
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                        StartHeight: int64(height),
                                        EndHeight: int64(
                                                endShortChanID.BlockHeight,
                                        ),
                                },
                        )
                }

                // Batch build all channel edges for disconnection.
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
                        ctx, s.cfg, db, rows,
                )
                if err != nil {
                        return err
                }

                removedChans = channelEdges

                err = s.deleteChannels(ctx, db, chanIDsToDelete)
                if err != nil {
                        return fmt.Errorf("unable to delete channels: %w", err)
                }

                return db.DeletePruneLogEntriesInRange(
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
                                StartHeight: int64(height),
                                EndHeight:   int64(endShortChanID.BlockHeight),
                        },
                )
        }, func() {
                removedChans = nil
        })
        if err != nil {
                return nil, fmt.Errorf("unable to disconnect block at "+
                        "height: %w", err)
        }

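        // Evict the removed channels from the reject and channel caches.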
        for _, channel := range removedChans {
                s.rejectCache.remove(channel.ChannelID)
                s.chanCache.remove(channel.ChannelID)
        }

        return removedChans, nil
}

// AddEdgeProof sets the proof of an existing edge in the graph database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
        proof *models.ChannelAuthProof) error {

        var (
                ctx       = context.TODO()
                scidBytes = channelIDToBytes(scid.ToUint64())
        )

        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                res, err := db.AddV1ChannelProof(
                        ctx, sqlc.AddV1ChannelProofParams{
                                Scid:              scidBytes,
                                Node1Signature:    proof.NodeSig1Bytes,
                                Node2Signature:    proof.NodeSig2Bytes,
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to add edge proof: %w", err)
                }

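                // Sanity check that exactly one channel row was updated with
                // the new proof.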
                n, err := res.RowsAffected()
                if err != nil {
                        return err
                }

                if n == 0 {
                        return fmt.Errorf("no rows affected when adding edge "+
                                "proof for SCID %v", scid)
                } else if n > 1 {
                        return fmt.Errorf("multiple rows affected when adding "+
                                "edge proof for SCID %v: %d rows affected",
                                scid, n)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return fmt.Errorf("unable to add edge proof: %w", err)
        }

        return nil
}

// PutClosedScid stores a SCID for a closed channel in the database. This is so
// that we can ignore channel announcements that we know to be closed without
// having to validate them and fetch a block.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
        var (
                ctx     = context.TODO()
                chanIDB = channelIDToBytes(scid.ToUint64())
        )

        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
                return db.InsertClosedChannel(ctx, chanIDB)
        }, sqldb.NoOpReset)
}

// IsClosedScid checks whether a channel identified by the passed in scid is
// closed. This helps avoid having to perform expensive validation checks.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
        var (
                ctx      = context.TODO()
                isClosed bool
                chanIDB  = channelIDToBytes(scid.ToUint64())
        )
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                var err error
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
                if err != nil {
                        return fmt.Errorf("unable to fetch closed channel: %w",
                                err)
                }

                return nil
        }, sqldb.NoOpReset)
        if err != nil {
                return false, fmt.Errorf("unable to fetch closed channel: %w",
                        err)
        }

        return isClosed, nil
}

// GraphSession will provide the call-back with access to a NodeTraverser
// instance which can be used to perform queries against the channel graph.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
        reset func()) error {

        var ctx = context.TODO()

        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
        }, reset)
}

// sqlNodeTraverser implements the NodeTraverser interface but with a backing
// read only transaction for a consistent view of the graph.
type sqlNodeTraverser struct {
        db    SQLQueries
        chain chainhash.Hash
}

// A compile-time assertion to ensure that sqlNodeTraverser implements the
// NodeTraverser interface.
var _ NodeTraverser = (*sqlNodeTraverser)(nil)

// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
func newSQLNodeTraverser(db SQLQueries,
        chain chainhash.Hash) *sqlNodeTraverser {

        return &sqlNodeTraverser{
                db:    db,
                chain: chain,
        }
}

// ForEachNodeDirectedChannel calls the callback for every channel of the given
// node.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
        cb func(channel *DirectedChannel) error, _ func()) error {

        ctx := context.TODO()

        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
}

// FetchNodeFeatures returns the features of the given node. If the node is
// unknown, assume no additional features are supported.
//
// NOTE: Part of the NodeTraverser interface.
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
        *lnwire.FeatureVector, error) {

        ctx := context.TODO()

        return fetchNodeFeatures(ctx, s.db, nodePub)
}

// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {

        toNodeCallback := func() route.Vertex {
                return nodePub
        }

        dbID, err := db.GetNodeIDByPubKey(
                ctx, sqlc.GetNodeIDByPubKeyParams{
                        Version: int16(lnwire.GossipVersion1),
                        PubKey:  nodePub[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return nil
        } else if err != nil {
                return fmt.Errorf("unable to fetch node: %w", err)
        }

        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(lnwire.GossipVersion1),
                        NodeID1: dbID,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Exit early if there are no channels for this node so we don't
        // do the unnecessary feature fetching.
        if len(rows) == 0 {
                return nil
        }

        features, err := getNodeFeatures(ctx, db, dbID)
        if err != nil {
                return fmt.Errorf("unable to fetch node features: %w", err)
        }

        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge := buildCacheableChannelInfo(
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
                        node1, node2,
                )

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildCachedChanPolicies(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
                )
                if err != nil {
                        return err
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                outPolicy, inPolicy := p1, p2
                if p1 != nil && node2 == nodePub {
                        outPolicy, inPolicy = p2, p1
                } else if p2 != nil && node1 != nodePub {
                        outPolicy, inPolicy = p2, p1
                }

                var cachedInPolicy *models.CachedEdgePolicy
                if inPolicy != nil {
                        cachedInPolicy = inPolicy
                        cachedInPolicy.ToNodePubKey = toNodeCallback
                        cachedInPolicy.ToNodeFeatures = features
                }

                directedChannel := &DirectedChannel{
                        ChannelID:    edge.ChannelID,
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
                        OtherNode:    edge.NodeKey2Bytes,
                        Capacity:     edge.Capacity,
                        OutPolicySet: outPolicy != nil,
                        InPolicy:     cachedInPolicy,
                }
                if outPolicy != nil {
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                                directedChannel.InboundFee = fee
                        })
                }

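                // If the caller's node is the channel's second node, then the
                // "other" node is node 1.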
                if nodePub == edge.NodeKey2Bytes {
                        directedChannel.OtherNode = edge.NodeKey1Bytes
                }

                if err := cb(directedChannel); err != nil {
                        return err
                }
        }

        return nil
}

// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node. It does so via pagination
// along with batch loading of the node feature bits.
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
                features *lnwire.FeatureVector) error) error {

        handleNode := func(_ context.Context,
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
                featureBits map[int64][]int) error {

                fv := lnwire.EmptyFeatureVector()
                if features, exists := featureBits[dbNode.ID]; exists {
                        for _, bit := range features {
                                fv.Set(lnwire.FeatureBit(bit))
                        }
                }

                var pub route.Vertex
                copy(pub[:], dbNode.PubKey)

                return processNode(dbNode.ID, pub, fv)
        }

        queryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {

                return db.ListNodeIDsAndPubKeys(
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
                                Version: int16(lnwire.GossipVersion1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
                return row.ID
        }

        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (map[int64][]int, error) {

                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
                batchQueryFunc, handleNode,
        )
}

// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        // Get all the V1 channels for this node.
        rows, err := db.ListChannelsByNodeID(
                ctx, sqlc.ListChannelsByNodeIDParams{
                        Version: int16(lnwire.GossipVersion1),
                        NodeID1: id,
                },
        )
        if err != nil {
                return fmt.Errorf("unable to fetch channels: %w", err)
        }

        // Collect all the channel and policy IDs.
        var (
                chanIDs   = make([]int64, 0, len(rows))
                policyIDs = make([]int64, 0, 2*len(rows))
        )
        for _, row := range rows {
                chanIDs = append(chanIDs, row.GraphChannel.ID)

                if row.Policy1ID.Valid {
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
                }
                if row.Policy2ID.Valid {
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
                }
        }

        batchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
        )
        if err != nil {
                return fmt.Errorf("unable to batch load channel data: %w", err)
        }

        // Call the call-back for each channel and its known policies.
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return fmt.Errorf("unable to build node vertices: %w",
                                err)
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                // Determine the outgoing and incoming policy for this
                // channel and node combo.
                p1ToNode := row.GraphChannel.NodeID2
                p2ToNode := row.GraphChannel.NodeID1
                outPolicy, inPolicy := p1, p2
                if (p1 != nil && p1ToNode == id) ||
                        (p2 != nil && p2ToNode != id) {

                        outPolicy, inPolicy = p2, p1
                }

                if err := cb(edge, outPolicy, inPolicy); err != nil {
                        return err
                }
        }

        return nil
}

// updateChanEdgePolicy upserts the channel policy info we have stored for
// a channel we already know of.
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
        error) {

        var (
                node1Pub, node2Pub route.Vertex
                isNode1            bool
                chanIDB            = channelIDToBytes(edge.ChannelID)
        )

        // Check that this edge policy refers to a channel that we already
        // know of. We do this explicitly so that we can return the appropriate
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
        // abort the transaction which would abort the entire batch.
        dbChan, err := tx.GetChannelAndNodesBySCID(
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
                        Scid:    chanIDB,
                        Version: int16(lnwire.GossipVersion1),
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return node1Pub, node2Pub, false, ErrEdgeNotFound
        } else if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "fetch channel(%v): %w", edge.ChannelID, err)
        }

        copy(node1Pub[:], dbChan.Node1PubKey)
        copy(node2Pub[:], dbChan.Node2PubKey)

        // Figure out which node this edge is from.
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
        nodeID := dbChan.NodeID1
        if !isNode1 {
                nodeID = dbChan.NodeID2
        }

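        // Extract the optional inbound fee fields, if set, so that they can be
        // stored as nullable columns.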
        var (
                inboundBase sql.NullInt64
                inboundRate sql.NullInt64
        )
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
        })

        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
                Version:     int16(lnwire.GossipVersion1),
                ChannelID:   dbChan.ID,
                NodeID:      nodeID,
                Timelock:    int32(edge.TimeLockDelta),
                FeePpm:      int64(edge.FeeProportionalMillionths),
                BaseFeeMsat: int64(edge.FeeBaseMSat),
                MinHtlcMsat: int64(edge.MinHTLC),
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
                Disabled: sql.NullBool{
                        Valid: true,
                        Bool:  edge.IsDisabled(),
                },
                MaxHtlcMsat: sql.NullInt64{
                        Valid: edge.MessageFlags.HasMaxHtlc(),
                        Int64: int64(edge.MaxHTLC),
                },
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
                InboundBaseFeeMsat:      inboundBase,
                InboundFeeRateMilliMsat: inboundRate,
                Signature:               edge.SigBytes,
        })
        if err != nil {
                return node1Pub, node2Pub, isNode1,
                        fmt.Errorf("unable to upsert edge policy: %w", err)
        }

        // Convert the flat extra opaque data into a map of TLV types to
        // values.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
                        "marshal extra opaque data: %w", err)
        }

        // Update the channel policy's extra signed fields.
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
        if err != nil {
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
                        "policy extra TLVs: %w", err)
        }

        return node1Pub, node2Pub, isNode1, nil
}

// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        pubKey route.Vertex) (int64, *models.Node, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        Version: int16(lnwire.GossipVersion1),
                        PubKey:  pubKey[:],
                },
        )
        if errors.Is(err, sql.ErrNoRows) {
                return 0, nil, ErrGraphNodeNotFound
        } else if err != nil {
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
        }

        node, err := buildNode(ctx, cfg, db, dbNode)
        if err != nil {
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
        }

        return dbNode.ID, node, nil
}

// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
// provided parameters.
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
        node2Pub route.Vertex) *models.CachedEdgeInfo {

        return &models.CachedEdgeInfo{
                ChannelID:     byteOrder.Uint64(scid),
                NodeKey1Bytes: node1Pub,
                NodeKey2Bytes: node2Pub,
                Capacity:      btcutil.Amount(capacity),
        }
}

// buildNode constructs a Node instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
        dbNode sqlc.GraphNode) (*models.Node, error) {

        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        return buildNodeWithBatchData(dbNode, data)
}

// buildNodeWithBatchData builds a models.Node instance
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
// features/addresses/extra fields, then the corresponding fields are expected
// to be present in the batchNodeData.
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
        batchData *batchNodeData) (*models.Node, error) {

        if dbNode.Version != int16(lnwire.GossipVersion1) {
                return nil, fmt.Errorf("unsupported node version: %d",
                        dbNode.Version)
        }

        var pub [33]byte
        copy(pub[:], dbNode.PubKey)

        node := models.NewV1ShellNode(pub)

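        // If we don't have a signature for the node, then it is only a shell
        // node record, so we return early with just the public key set.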
        if len(dbNode.Signature) == 0 {
                return node, nil
        }

        node.AuthSigBytes = dbNode.Signature

        if dbNode.Alias.Valid {
                node.Alias = fn.Some(dbNode.Alias.String)
        }
        if dbNode.LastUpdate.Valid {
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
        }

        var err error
        if dbNode.Color.Valid {
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode color: %w",
                                err)
                }

                node.Color = fn.Some(nodeColor)
        }

        // Use preloaded features.
        if features, exists := batchData.features[dbNode.ID]; exists {
                fv := lnwire.EmptyFeatureVector()
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
                node.Features = fv
        }

        // Use preloaded addresses.
        addresses, exists := batchData.addresses[dbNode.ID]
        if exists && len(addresses) > 0 {
                node.Addresses, err = buildNodeAddresses(addresses)
                if err != nil {
                        return nil, fmt.Errorf("unable to build addresses "+
                                "for node(%d): %w", dbNode.ID, err)
                }
        }

        // Use preloaded extra fields.
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
                if err != nil {
                        return nil, fmt.Errorf("unable to serialize extra "+
                                "signed fields: %w", err)
                }
                if len(recs) != 0 {
                        node.ExtraOpaqueData = recs
                }
        }

        return node, nil
}

// forEachNodeInBatch fetches all nodes in the provided batch, builds them
// with the preloaded data, and executes the provided callback for each node.
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodes []sqlc.GraphNode,
        cb func(dbID int64, node *models.Node) error) error {

        // Extract node IDs for batch loading.
        nodeIDs := make([]int64, len(nodes))
        for i, node := range nodes {
                nodeIDs[i] = node.ID
        }

        // Batch load all related data for this page.
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
        if err != nil {
                return fmt.Errorf("unable to batch load node data: %w", err)
        }

        for _, dbNode := range nodes {
                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build node(id=%d): %w",
                                dbNode.ID, err)
                }

                if err := cb(dbNode.ID, node); err != nil {
                        return fmt.Errorf("callback failed for node(id=%d): %w",
                                dbNode.ID, err)
                }
        }

        return nil
}

// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
        nodeID int64) (*lnwire.FeatureVector, error) {

        rows, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
                        nodeID, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, feature := range rows {
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
        }

        return features, nil
}
3616

3617
// upsertNodeAncillaryData updates the node's features, addresses, and extra
3618
// signed fields. This is common logic shared by upsertNode and
3619
// upsertSourceNode.
3620
func upsertNodeAncillaryData(ctx context.Context, db SQLQueries,
3621
        nodeID int64, node *models.Node) error {
×
3622

×
3623
        // Update the node's features.
×
3624
        err := upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3625
        if err != nil {
×
3626
                return fmt.Errorf("inserting node features: %w", err)
×
3627
        }
×
3628

3629
        // Update the node's addresses.
3630
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3631
        if err != nil {
×
3632
                return fmt.Errorf("inserting node addresses: %w", err)
×
3633
        }
×
3634

3635
        // Convert the flat extra opaque data into a map of TLV types to
3636
        // values.
3637
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3638
        if err != nil {
×
3639
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
3640
                        err)
×
3641
        }
×
3642

3643
        // Update the node's extra signed fields.
3644
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3645
        if err != nil {
×
3646
                return fmt.Errorf("inserting node extra TLVs: %w", err)
×
3647
        }
×
3648

3649
        return nil
×
3650
}

// populateNodeParams populates the common node parameters from a models.Node.
// This is a helper for building UpsertNodeParams and UpsertSourceNodeParams.
func populateNodeParams(node *models.Node,
        setParams func(lastUpdate sql.NullInt64, alias,
                colorStr sql.NullString, signature []byte)) error {

        if !node.HaveAnnouncement() {
                return nil
        }

        switch node.Version {
        case lnwire.GossipVersion1:
                lastUpdate := sqldb.SQLInt64(node.LastUpdate.Unix())
                var alias, colorStr sql.NullString

                node.Color.WhenSome(func(rgba color.RGBA) {
                        colorStr = sqldb.SQLStrValid(EncodeHexColor(rgba))
                })
                node.Alias.WhenSome(func(s string) {
                        alias = sqldb.SQLStrValid(s)
                })

                setParams(lastUpdate, alias, colorStr, node.AuthSigBytes)

        case lnwire.GossipVersion2:
                // No-op for now.

        default:
                return fmt.Errorf("unknown gossip version: %d", node.Version)
        }

        return nil
}

// buildNodeUpsertParams builds the parameters for upserting a node using the
// strict UpsertNode query (requires timestamp to be increasing).
func buildNodeUpsertParams(node *models.Node) (sqlc.UpsertNodeParams, error) {
        params := sqlc.UpsertNodeParams{
                Version: int16(lnwire.GossipVersion1),
                PubKey:  node.PubKeyBytes[:],
        }

        err := populateNodeParams(
                node, func(lastUpdate sql.NullInt64, alias,
                        colorStr sql.NullString,
                        signature []byte) {

                        params.LastUpdate = lastUpdate
                        params.Alias = alias
                        params.Color = colorStr
                        params.Signature = signature
                })

        return params, err
}

// buildSourceNodeUpsertParams builds the parameters for upserting the source
// node using the lenient UpsertSourceNode query (allows same timestamp).
func buildSourceNodeUpsertParams(node *models.Node) (
        sqlc.UpsertSourceNodeParams, error) {

        params := sqlc.UpsertSourceNodeParams{
                Version: int16(lnwire.GossipVersion1),
                PubKey:  node.PubKeyBytes[:],
        }

        err := populateNodeParams(
                node, func(lastUpdate sql.NullInt64, alias,
                        colorStr sql.NullString, signature []byte) {

                        params.LastUpdate = lastUpdate
                        params.Alias = alias
                        params.Color = colorStr
                        params.Signature = signature
                },
        )

        return params, err
}

// upsertSourceNode upserts the source node record into the database using a
// less strict upsert that allows updates even when the timestamp hasn't
// changed. This is necessary to handle concurrent updates to our own node
// during startup and runtime. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertSourceNode(ctx context.Context, db SQLQueries,
        node *models.Node) (int64, error) {

        params, err := buildSourceNodeUpsertParams(node)
        if err != nil {
                return 0, err
        }

        nodeID, err := db.UpsertSourceNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting source node(%x): %w",
                        node.PubKeyBytes, err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveAnnouncement() {
                return nodeID, nil
        }

        // Update the ancillary node data (features, addresses, extra fields).
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
        if err != nil {
                return 0, err
        }

        return nodeID, nil
}

// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
        node *models.Node) (int64, error) {

        params, err := buildNodeUpsertParams(node)
        if err != nil {
                return 0, err
        }

        nodeID, err := db.UpsertNode(ctx, params)
        if err != nil {
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
                        err)
        }

        // We can exit here if we don't have the announcement yet.
        if !node.HaveAnnouncement() {
                return nodeID, nil
        }

        // Update the ancillary node data (features, addresses, extra fields).
        err = upsertNodeAncillaryData(ctx, db, nodeID, node)
        if err != nil {
                return 0, err
        }

        return nodeID, nil
}

// upsertNodeFeatures updates the node's features in the node_features table.
// This includes deleting any feature bits no longer present and inserting any
// new feature bits. If the feature bit does not yet exist in the features
// table, then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
        features *lnwire.FeatureVector) error {

        // Get any existing features for the node.
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
                return err
        }

        // Copy the node's latest set of feature bits.
        newFeatures := make(map[int32]struct{})
        if features != nil {
                for feature := range features.Features() {
                        newFeatures[int32(feature)] = struct{}{}
                }
        }

        // For each feature that already exists in the DB: if it is still
        // present in the new set, remove it from the in-memory map; otherwise,
        // delete it from the database.
        for _, feature := range existingFeatures {
                // The feature is still present, so there are no updates to be
                // made.
                if _, ok := newFeatures[feature.FeatureBit]; ok {
                        delete(newFeatures, feature.FeatureBit)
                        continue
                }

                // The feature is no longer present, so we remove it from the
                // database.
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature.FeatureBit,
                })
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) "+
                                "feature(%v): %w", nodeID, feature.FeatureBit,
                                err)
                }
        }

        // Any remaining entries in newFeatures are new features that need to
        // be added to the database for the first time.
        for feature := range newFeatures {
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
                        NodeID:     nodeID,
                        FeatureBit: feature,
                })
                if err != nil {
                        return fmt.Errorf("unable to insert node(%d) "+
                                "feature(%v): %w", nodeID, feature, err)
                }
        }

        return nil
}
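
// exampleFeatureDiff is an illustrative, self-contained sketch (not used by
// the store) of the reconciliation performed by upsertNodeFeatures above:
// given the feature bits already persisted and the node's latest set, it
// returns the bits to delete from and insert into the node_features table.
func exampleFeatureDiff(existing, latest []int32) (toDelete, toInsert []int32) {
        // Start by assuming every bit in the latest set still needs inserting.
        newBits := make(map[int32]struct{}, len(latest))
        for _, bit := range latest {
                newBits[bit] = struct{}{}
        }

        // Bits already in the DB are either kept (drop them from newBits) or
        // stale (schedule them for deletion).
        for _, bit := range existing {
                if _, ok := newBits[bit]; ok {
                        delete(newBits, bit)
                        continue
                }

                toDelete = append(toDelete, bit)
        }

        // Whatever remains was never persisted and must be inserted.
        for bit := range newBits {
                toInsert = append(toInsert, bit)
        }

        return toDelete, toInsert
}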

// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {

        rows, err := queries.GetNodeFeaturesByPubKey(
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
                        PubKey:  nodePub[:],
                        Version: int16(lnwire.GossipVersion1),
                },
        )
        if err != nil {
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
                        nodePub, err)
        }

        features := lnwire.EmptyFeatureVector()
        for _, bit := range rows {
                features.Set(lnwire.FeatureBit(bit))
        }

        return features, nil
}

// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialised.
type dbAddressType uint8

const (
        addressTypeIPv4   dbAddressType = 1
        addressTypeIPv6   dbAddressType = 2
        addressTypeTorV2  dbAddressType = 3
        addressTypeTorV3  dbAddressType = 4
        addressTypeDNS    dbAddressType = 5
        addressTypeOpaque dbAddressType = math.MaxInt8
)

// collectAddressRecords collects the addresses from the provided
// net.Addr slice and returns a map of dbAddressType to a slice of address
// strings.
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
        error) {

        // Copy the node's latest set of addresses.
        newAddresses := map[dbAddressType][]string{
                addressTypeIPv4:   {},
                addressTypeIPv6:   {},
                addressTypeTorV2:  {},
                addressTypeTorV3:  {},
                addressTypeDNS:    {},
                addressTypeOpaque: {},
        }
        addAddr := func(t dbAddressType, addr net.Addr) {
                newAddresses[t] = append(newAddresses[t], addr.String())
        }

        for _, address := range addresses {
                switch addr := address.(type) {
                case *net.TCPAddr:
                        if ip4 := addr.IP.To4(); ip4 != nil {
                                addAddr(addressTypeIPv4, addr)
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
                                addAddr(addressTypeIPv6, addr)
                        } else {
                                return nil, fmt.Errorf("unhandled IP "+
                                        "address: %v", addr)
                        }

                case *tor.OnionAddr:
                        switch len(addr.OnionService) {
                        case tor.V2Len:
                                addAddr(addressTypeTorV2, addr)
                        case tor.V3Len:
                                addAddr(addressTypeTorV3, addr)
                        default:
                                return nil, fmt.Errorf("invalid length for " +
                                        "a tor address")
                        }

                case *lnwire.DNSAddress:
                        addAddr(addressTypeDNS, addr)

                case *lnwire.OpaqueAddrs:
                        addAddr(addressTypeOpaque, addr)

                default:
                        return nil, fmt.Errorf("unhandled address type: %T",
                                addr)
                }
        }

        return newAddresses, nil
}
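
// exampleCollectAddressRecords is an illustrative sketch (not part of the
// store) showing how a node advertising one IPv4 and one IPv6 TCP address is
// bucketed by collectAddressRecords; the Tor, DNS and opaque buckets stay
// empty but are always present in the returned map.
func exampleCollectAddressRecords() (map[dbAddressType][]string, error) {
        addrs := []net.Addr{
                &net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 9735},
                &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 9735},
        }

        // Expected shape (addresses are formatted via net.Addr.String()):
        //   addressTypeIPv4: ["203.0.113.1:9735"]
        //   addressTypeIPv6: ["[2001:db8::1]:9735"]
        return collectAddressRecords(addrs)
}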

// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
        addresses []net.Addr) error {

        // Delete any existing addresses for the node. This is required since
        // even if the new set of addresses is the same, the ordering may have
        // changed for a given address type.
        err := db.DeleteNodeAddresses(ctx, nodeID)
        if err != nil {
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
                        nodeID, err)
        }

        newAddresses, err := collectAddressRecords(addresses)
        if err != nil {
                return err
        }

        // Insert the new set of addresses, preserving the position of each
        // address within its address type.
        for addrType, addrList := range newAddresses {
                for position, addr := range addrList {
                        err := db.UpsertNodeAddress(
                                ctx, sqlc.UpsertNodeAddressParams{
                                        NodeID:   nodeID,
                                        Type:     int16(addrType),
                                        Address:  addr,
                                        Position: int32(position),
                                },
                        )
                        if err != nil {
                                return fmt.Errorf("unable to insert "+
                                        "node(%d) address(%v): %w", nodeID,
                                        addr, err)
                        }
                }
        }

        return nil
}
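
// exampleAddressRow and exampleAddressRows are an illustrative sketch (not
// part of the store) of the rows that upsertNodeAddresses persists after the
// delete step: one row per address, keyed by its type and by its position
// within that type, which is what lets getNodeAddresses return the addresses
// in their original order so a reconstructed announcement still verifies.
type exampleAddressRow struct {
        Type     dbAddressType
        Position int32
        Address  string
}

func exampleAddressRows(byType map[dbAddressType][]string) []exampleAddressRow {
        var rows []exampleAddressRow
        for addrType, addrList := range byType {
                for position, addr := range addrList {
                        rows = append(rows, exampleAddressRow{
                                Type:     addrType,
                                Position: int32(position),
                                Address:  addr,
                        })
                }
        }

        return rows
}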

// getNodeAddresses fetches the addresses for a node with the given DB ID.
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
        error) {

        // GetNodeAddresses ensures that the addresses for a given type are
        // returned in the same order as they were inserted.
        rows, err := db.GetNodeAddresses(ctx, id)
        if err != nil {
                return nil, err
        }

        addresses := make([]net.Addr, 0, len(rows))
        for _, row := range rows {
                address := row.Address

                addr, err := parseAddress(dbAddressType(row.Type), address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address "+
                                "for node(%d): %v: %w", id, address, err)
                }

                addresses = append(addresses, addr)
        }

        // If we have no addresses, then we'll return nil instead of an
        // empty slice.
        if len(addresses) == 0 {
                addresses = nil
        }

        return addresses, nil
}

// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
        nodeID int64, extraFields map[uint64][]byte) error {

        // Get any existing extra signed fields for the node.
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
        if err != nil {
                return err
        }

        // Make a lookup map of the existing field types so that we can use it
        // to keep track of any fields we should delete.
        m := make(map[uint64]bool)
        for _, field := range existingFields {
                m[uint64(field.Type)] = true
        }

        // For all the new fields, we'll upsert them and remove them from the
        // map of existing fields.
        for tlvType, value := range extraFields {
                err = db.UpsertNodeExtraType(
                        ctx, sqlc.UpsertNodeExtraTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                                Value:  value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }

                // Remove the field from the map of existing fields if it was
                // present.
                delete(m, tlvType)
        }

        // For all the fields that are left in the map of existing fields, we'll
        // delete them as they are no longer present in the new set of fields.
        for tlvType := range m {
                err = db.DeleteExtraNodeType(
                        ctx, sqlc.DeleteExtraNodeTypeParams{
                                NodeID: nodeID,
                                Type:   int64(tlvType),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to delete node(%d) extra "+
                                "signed field(%v): %w", nodeID, tlvType, err)
                }
        }

        return nil
}

// srcNodeInfo holds the information about the source node of the graph.
type srcNodeInfo struct {
        // id is the DB level ID of the source node entry in the "nodes" table.
        id int64

        // pub is the public key of the source node.
        pub route.Vertex
}

// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
        version lnwire.GossipVersion) (int64, route.Vertex, error) {

        s.srcNodeMu.Lock()
        defer s.srcNodeMu.Unlock()

        // If we already have the source node ID and pub key cached, then
        // return them.
        if info, ok := s.srcNodes[version]; ok {
                return info.id, info.pub, nil
        }

        var pubKey route.Vertex

        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
        if err != nil {
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
                        err)
        }

        if len(nodes) == 0 {
                return 0, pubKey, ErrSourceNodeNotSet
        } else if len(nodes) > 1 {
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
                        "protocol %s found", version)
        }

        copy(pubKey[:], nodes[0].PubKey)

        s.srcNodes[version] = &srcNodeInfo{
                id:  nodes[0].NodeID,
                pub: pubKey,
        }

        return nodes[0].NodeID, pubKey, nil
}
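
// exampleLazyCache is an illustrative, generic sketch (not part of the store)
// of the caching pattern used by getSourceNode above: the first lookup for a
// key pays the cost of the fetch, later lookups are served from the in-memory
// map, and the mutex makes the check-then-fetch sequence safe for concurrent
// callers.
type exampleLazyCache[K comparable, V any] struct {
        mu    sync.Mutex
        items map[K]V
        fetch func(K) (V, error)
}

func (c *exampleLazyCache[K, V]) get(key K) (V, error) {
        c.mu.Lock()
        defer c.mu.Unlock()

        // Fast path: the value has already been fetched once.
        if v, ok := c.items[key]; ok {
                return v, nil
        }

        // Slow path: fetch the value and cache it for subsequent callers.
        v, err := c.fetch(key)
        if err != nil {
                var zero V
                return zero, err
        }

        if c.items == nil {
                c.items = make(map[K]V)
        }
        c.items[key] = v

        return v, nil
}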

// marshalExtraOpaqueData takes a flat byte slice and parses it as a TLV
// stream. This then produces a map from TLV type to value. If the input is
// not a valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
        r := bytes.NewReader(data)

        tlvStream, err := tlv.NewStream()
        if err != nil {
                return nil, err
        }

        // Since ExtraOpaqueData is provided by a potentially malicious peer,
        // pass it into the P2P decoding variant.
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
        if err != nil {
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
        }
        if len(parsedTypes) == 0 {
                return nil, nil
        }

        records := make(map[uint64][]byte)
        for k, v := range parsedTypes {
                records[uint64(k)] = v
        }

        return records, nil
}
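
// exampleExtraDataRoundTrip is an illustrative sketch (not part of the store)
// of how a TLV type => value map relates to the flat ExtraOpaqueData blob: the
// map form is what the extra-types tables store, the flat form is what the
// wire messages carry. The single record used here (type 65537) is an
// arbitrary example value.
func exampleExtraDataRoundTrip() (map[uint64][]byte, error) {
        records := map[uint64][]byte{
                65537: {0x01, 0x02},
        }

        // Serialize the map into the flat TLV stream format.
        raw, err := lnwire.CustomRecords(records).Serialize()
        if err != nil {
                return nil, err
        }

        // Parse the flat stream back into the map form persisted by
        // upsertNodeExtraSignedFields and friends.
        return marshalExtraOpaqueData(raw)
}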

// insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries,
        edge *models.ChannelEdgeInfo) error {

        // Make sure that at least a "shell" entry for each node is present in
        // the nodes table.
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
        if err != nil {
                return fmt.Errorf("unable to create shell node: %w", err)
        }

        var capacity sql.NullInt64
        if edge.Capacity != 0 {
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
        }

        createParams := sqlc.CreateChannelParams{
                Version:     int16(lnwire.GossipVersion1),
                Scid:        channelIDToBytes(edge.ChannelID),
                NodeID1:     node1DBID,
                NodeID2:     node2DBID,
                Outpoint:    edge.ChannelPoint.String(),
                Capacity:    capacity,
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
        }

        if edge.AuthProof != nil {
                proof := edge.AuthProof

                createParams.Node1Signature = proof.NodeSig1Bytes
                createParams.Node2Signature = proof.NodeSig2Bytes
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
        }

        // Insert the new channel record.
        dbChanID, err := db.CreateChannel(ctx, createParams)
        if err != nil {
                return err
        }

        // Insert any channel features.
        for feature := range edge.Features.Features() {
                err = db.InsertChannelFeature(
                        ctx, sqlc.InsertChannelFeatureParams{
                                ChannelID:  dbChanID,
                                FeatureBit: int32(feature),
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert channel(%d) "+
                                "feature(%v): %w", dbChanID, feature, err)
                }
        }

        // Finally, insert any extra TLV fields in the channel announcement.
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
        if err != nil {
                return fmt.Errorf("unable to marshal extra opaque data: %w",
                        err)
        }

        for tlvType, value := range extra {
                err := db.UpsertChannelExtraType(
                        ctx, sqlc.UpsertChannelExtraTypeParams{
                                ChannelID: dbChanID,
                                Type:      int64(tlvType),
                                Value:     value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to upsert channel(%d) "+
                                "extra signed field(%v): %w", edge.ChannelID,
                                tlvType, err)
                }
        }

        return nil
}

// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
        pubKey route.Vertex) (int64, error) {

        dbNode, err := db.GetNodeByPubKey(
                ctx, sqlc.GetNodeByPubKeyParams{
                        PubKey:  pubKey[:],
                        Version: int16(lnwire.GossipVersion1),
                },
        )
        // The node exists. Return the ID.
        if err == nil {
                return dbNode.ID, nil
        } else if !errors.Is(err, sql.ErrNoRows) {
                return 0, err
        }

        // Otherwise, the node does not exist, so we create a shell entry for
        // it.
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
                Version: int16(lnwire.GossipVersion1),
                PubKey:  pubKey[:],
        })
        if err != nil {
                return 0, fmt.Errorf("unable to create shell node: %w", err)
        }

        return id, nil
}
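
// exampleGetOrCreate is an illustrative, generic sketch (not part of the
// store) of the get-or-create shape used by maybeCreateShellNode above: a
// lookup that distinguishes "not found" (sql.ErrNoRows) from real failures,
// followed by an insert only in the not-found case.
func exampleGetOrCreate(lookup func() (int64, error),
        create func() (int64, error)) (int64, error) {

        id, err := lookup()
        // Found: return the existing row's ID.
        if err == nil {
                return id, nil
        }

        // Any error other than "no rows" is a real failure.
        if !errors.Is(err, sql.ErrNoRows) {
                return 0, err
        }

        // Not found: create the row and return its ID.
        return create()
}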

// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
// the database. This includes deleting any existing types and then inserting
// the new types.
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
        chanPolicyID int64, extraFields map[uint64][]byte) error {

        // Delete all existing extra signed fields for the channel policy.
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
        if err != nil {
                return fmt.Errorf("unable to delete "+
                        "existing policy extra signed fields for policy %d: %w",
                        chanPolicyID, err)
        }

        // Insert all new extra signed fields for the channel policy.
        for tlvType, value := range extraFields {
                err = db.UpsertChanPolicyExtraType(
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
                                ChannelPolicyID: chanPolicyID,
                                Type:            int64(tlvType),
                                Value:           value,
                        },
                )
                if err != nil {
                        return fmt.Errorf("unable to insert "+
                                "channel_policy(%d) extra signed field(%v): %w",
                                chanPolicyID, tlvType, err)
                }
        }

        return nil
}

// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
// provided dbChan row and also fetches any other required information to
// construct the edge info.
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {

        data, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel data: %w",
                        err)
        }

        return buildEdgeInfoWithBatchData(
                cfg.ChainHash, dbChan, node1, node2, data,
        )
}

// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {

        if dbChan.Version != int16(lnwire.GossipVersion1) {
                return nil, fmt.Errorf("unsupported channel version: %d",
                        dbChan.Version)
        }

        // Use pre-loaded features and extra TLV types.
        fv := lnwire.EmptyFeatureVector()
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
                for _, bit := range features {
                        fv.Set(lnwire.FeatureBit(bit))
                }
        }

        var extras map[uint64][]byte
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
        if exists {
                extras = channelExtras
        } else {
                extras = make(map[uint64][]byte)
        }

        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
        if err != nil {
                return nil, err
        }

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }
        if recs == nil {
                recs = make([]byte, 0)
        }

        var btcKey1, btcKey2 route.Vertex
        copy(btcKey1[:], dbChan.BitcoinKey1)
        copy(btcKey2[:], dbChan.BitcoinKey2)

        channel := &models.ChannelEdgeInfo{
                ChainHash:        chain,
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
                NodeKey1Bytes:    node1,
                NodeKey2Bytes:    node2,
                BitcoinKey1Bytes: btcKey1,
                BitcoinKey2Bytes: btcKey2,
                ChannelPoint:     *op,
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
                Features:         fv,
                ExtraOpaqueData:  recs,
        }

        // We always set all the signatures at the same time, so we can
        // safely check if one signature is present to determine if we have the
        // rest of the signatures for the auth proof.
        if len(dbChan.Bitcoin1Signature) > 0 {
                channel.AuthProof = &models.ChannelAuthProof{
                        NodeSig1Bytes:    dbChan.Node1Signature,
                        NodeSig2Bytes:    dbChan.Node2Signature,
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
                }
        }

        return channel, nil
}

// buildNodeVertices is a helper that converts raw node public keys
// into route.Vertex instances.
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
        route.Vertex, error) {

        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node1 pubkey: %w", err)
        }

        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
        if err != nil {
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
                        "create vertex from node2 pubkey: %w", err)
        }

        return node1Vertex, node2Vertex, nil
}

// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy records and
// also retrieves all the extra info required to build the complete
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
// the provided sqlc.GraphChannelPolicy records are nil.
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        if dbPol1 == nil && dbPol2 == nil {
                return nil, nil, nil
        }

        var policyIDs = make([]int64, 0, 2)
        if dbPol1 != nil {
                policyIDs = append(policyIDs, dbPol1.ID)
        }
        if dbPol2 != nil {
                policyIDs = append(policyIDs, dbPol2.ID)
        }

        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
// then nil is returned for it.
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
        *models.CachedEdgePolicy, error) {

        var p1, p2 *models.CachedEdgePolicy
        if dbPol1 != nil {
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
                if err != nil {
                        return nil, nil, err
                }

                p1 = models.NewCachedPolicy(policy1)
        }
        if dbPol2 != nil {
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
                if err != nil {
                        return nil, nil, err
                }

                p2 = models.NewCachedPolicy(policy2)
        }

        return p1, p2, nil
}

// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
// provided sqlc.GraphChannelPolicy and other required information.
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
        extras map[uint64][]byte,
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {

        recs, err := lnwire.CustomRecords(extras).Serialize()
        if err != nil {
                return nil, fmt.Errorf("unable to serialize extra signed "+
                        "fields: %w", err)
        }

        var inboundFee fn.Option[lnwire.Fee]
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
                dbPolicy.InboundBaseFeeMsat.Valid {

                inboundFee = fn.Some(lnwire.Fee{
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
                })
        }

        return &models.ChannelEdgePolicy{
                SigBytes:  dbPolicy.Signature,
                ChannelID: channelID,
                LastUpdate: time.Unix(
                        dbPolicy.LastUpdate.Int64, 0,
                ),
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
                        dbPolicy.MessageFlags,
                ),
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
                        dbPolicy.ChannelFlags,
                ),
                TimeLockDelta: uint16(dbPolicy.Timelock),
                MinHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MinHtlcMsat,
                ),
                MaxHTLC: lnwire.MilliSatoshi(
                        dbPolicy.MaxHtlcMsat.Int64,
                ),
                FeeBaseMSat: lnwire.MilliSatoshi(
                        dbPolicy.BaseFeeMsat,
                ),
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
                ToNode:                    toNode,
                InboundFee:                inboundFee,
                ExtraOpaqueData:           recs,
        }, nil
}

// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the
// given row, which is expected to be a sqlc type that contains channel policy
// information. It returns two policies, which may be nil if the policy
// information is not present in the row.
//
//nolint:ll,dupl,funlen
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
        *sqlc.GraphChannelPolicy, error) {

        var policy1, policy2 *sqlc.GraphChannelPolicy
        switch r := row.(type) {
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
                if r.Policy1Timelock.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                        }
                }
                if r.Policy2Timelock.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelByOutpointWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelBySCIDWithPoliciesRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsForNodeIDsRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
                                Version:                 r.Policy1Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy1NodeID.Int64,
                                Timelock:                r.Policy1Timelock.Int32,
                                FeePpm:                  r.Policy1FeePpm.Int64,
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
                                LastUpdate:              r.Policy1LastUpdate,
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
                                Disabled:                r.Policy1Disabled,
                                MessageFlags:            r.Policy1MessageFlags,
                                ChannelFlags:            r.Policy1ChannelFlags,
                                Signature:               r.Policy1Signature,
                        }
                }
                if r.Policy2ID.Valid {
                        policy2 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy2ID.Int64,
                                Version:                 r.Policy2Version.Int16,
                                ChannelID:               r.GraphChannel.ID,
                                NodeID:                  r.Policy2NodeID.Int64,
                                Timelock:                r.Policy2Timelock.Int32,
                                FeePpm:                  r.Policy2FeePpm.Int64,
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
                                LastUpdate:              r.Policy2LastUpdate,
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
                                Disabled:                r.Policy2Disabled,
                                MessageFlags:            r.Policy2MessageFlags,
                                ChannelFlags:            r.Policy2ChannelFlags,
                                Signature:               r.Policy2Signature,
                        }
                }

                return policy1, policy2, nil

        case sqlc.ListChannelsByNodeIDRow:
                if r.Policy1ID.Valid {
                        policy1 = &sqlc.GraphChannelPolicy{
                                ID:                      r.Policy1ID.Int64,
×
4819
                                Version:                 r.Policy1Version.Int16,
×
4820
                                ChannelID:               r.GraphChannel.ID,
×
4821
                                NodeID:                  r.Policy1NodeID.Int64,
×
4822
                                Timelock:                r.Policy1Timelock.Int32,
×
4823
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4824
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4825
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4826
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4827
                                LastUpdate:              r.Policy1LastUpdate,
×
4828
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4829
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4830
                                Disabled:                r.Policy1Disabled,
×
4831
                                MessageFlags:            r.Policy1MessageFlags,
×
4832
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4833
                                Signature:               r.Policy1Signature,
×
4834
                        }
×
4835
                }
×
4836
                if r.Policy2ID.Valid {
×
4837
                        policy2 = &sqlc.GraphChannelPolicy{
×
4838
                                ID:                      r.Policy2ID.Int64,
×
4839
                                Version:                 r.Policy2Version.Int16,
×
4840
                                ChannelID:               r.GraphChannel.ID,
×
4841
                                NodeID:                  r.Policy2NodeID.Int64,
×
4842
                                Timelock:                r.Policy2Timelock.Int32,
×
4843
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4844
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4845
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4846
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4847
                                LastUpdate:              r.Policy2LastUpdate,
×
4848
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4849
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4850
                                Disabled:                r.Policy2Disabled,
×
4851
                                MessageFlags:            r.Policy2MessageFlags,
×
4852
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4853
                                Signature:               r.Policy2Signature,
×
4854
                        }
×
4855
                }
×
4856

4857
                return policy1, policy2, nil
×
4858

4859
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4860
                if r.Policy1ID.Valid {
×
4861
                        policy1 = &sqlc.GraphChannelPolicy{
×
4862
                                ID:                      r.Policy1ID.Int64,
×
4863
                                Version:                 r.Policy1Version.Int16,
×
4864
                                ChannelID:               r.GraphChannel.ID,
×
4865
                                NodeID:                  r.Policy1NodeID.Int64,
×
4866
                                Timelock:                r.Policy1Timelock.Int32,
×
4867
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4868
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4869
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4870
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4871
                                LastUpdate:              r.Policy1LastUpdate,
×
4872
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4873
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4874
                                Disabled:                r.Policy1Disabled,
×
4875
                                MessageFlags:            r.Policy1MessageFlags,
×
4876
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4877
                                Signature:               r.Policy1Signature,
×
4878
                        }
×
4879
                }
×
4880
                if r.Policy2ID.Valid {
×
4881
                        policy2 = &sqlc.GraphChannelPolicy{
×
4882
                                ID:                      r.Policy2ID.Int64,
×
4883
                                Version:                 r.Policy2Version.Int16,
×
4884
                                ChannelID:               r.GraphChannel.ID,
×
4885
                                NodeID:                  r.Policy2NodeID.Int64,
×
4886
                                Timelock:                r.Policy2Timelock.Int32,
×
4887
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4888
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4889
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4890
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4891
                                LastUpdate:              r.Policy2LastUpdate,
×
4892
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4893
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4894
                                Disabled:                r.Policy2Disabled,
×
4895
                                MessageFlags:            r.Policy2MessageFlags,
×
4896
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4897
                                Signature:               r.Policy2Signature,
×
4898
                        }
×
4899
                }
×
4900

4901
                return policy1, policy2, nil
×
4902

4903
        case sqlc.GetChannelsByIDsRow:
×
4904
                if r.Policy1ID.Valid {
×
4905
                        policy1 = &sqlc.GraphChannelPolicy{
×
4906
                                ID:                      r.Policy1ID.Int64,
×
4907
                                Version:                 r.Policy1Version.Int16,
×
4908
                                ChannelID:               r.GraphChannel.ID,
×
4909
                                NodeID:                  r.Policy1NodeID.Int64,
×
4910
                                Timelock:                r.Policy1Timelock.Int32,
×
4911
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4912
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4913
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4914
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4915
                                LastUpdate:              r.Policy1LastUpdate,
×
4916
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4917
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4918
                                Disabled:                r.Policy1Disabled,
×
4919
                                MessageFlags:            r.Policy1MessageFlags,
×
4920
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4921
                                Signature:               r.Policy1Signature,
×
4922
                        }
×
4923
                }
×
4924
                if r.Policy2ID.Valid {
×
4925
                        policy2 = &sqlc.GraphChannelPolicy{
×
4926
                                ID:                      r.Policy2ID.Int64,
×
4927
                                Version:                 r.Policy2Version.Int16,
×
4928
                                ChannelID:               r.GraphChannel.ID,
×
4929
                                NodeID:                  r.Policy2NodeID.Int64,
×
4930
                                Timelock:                r.Policy2Timelock.Int32,
×
4931
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4932
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4933
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4934
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4935
                                LastUpdate:              r.Policy2LastUpdate,
×
4936
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4937
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4938
                                Disabled:                r.Policy2Disabled,
×
4939
                                MessageFlags:            r.Policy2MessageFlags,
×
4940
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4941
                                Signature:               r.Policy2Signature,
×
4942
                        }
×
4943
                }
×
4944

4945
                return policy1, policy2, nil
×
4946

4947
        default:
×
4948
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4949
                        "extractChannelPolicies: %T", r)
×
4950
        }
4951
}
4952

4953
// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) []byte {
        var chanIDB [8]byte
        byteOrder.PutUint64(chanIDB[:], channelID)

        return chanIDB[:]
}

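// Illustrative usage sketch (editor addition, not part of the upstream file):
// the resulting 8-byte slice is what the sqlc query params expect for an SCID
// column, e.g. UpsertZombieChannelParams.Scid in handleZombieMarking below.
// The SCID value here is made up.
//
//	scidBytes := channelIDToBytes(uint64(871537651098779648))
//	// scidBytes has length 8 and is encoded with the package-level
//	// byteOrder.
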
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
        if len(addresses) == 0 {
                return nil, nil
        }

        result := make([]net.Addr, 0, len(addresses))
        for _, addr := range addresses {
                netAddr, err := parseAddress(addr.addrType, addr.address)
                if err != nil {
                        return nil, fmt.Errorf("unable to parse address %s "+
                                "of type %d: %w", addr.address, addr.addrType,
                                err)
                }
                if netAddr != nil {
                        result = append(result, netAddr)
                }
        }

        // If we have no valid addresses, return nil instead of empty slice.
        if len(result) == 0 {
                return nil, nil
        }

        return result, nil
}

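// Illustrative usage sketch (editor addition, not part of the upstream file):
// converting address rows that were previously collected into nodeAddress
// values back into net.Addr instances. The literal address is made up.
//
//	addrs, err := buildNodeAddresses([]nodeAddress{{
//		addrType: addressTypeIPv4,
//		position: 0,
//		address:  "203.0.113.7:9735",
//	}})
//	if err != nil {
//		// Handle the parse failure.
//	}
//	// On success, addrs contains a single *net.TCPAddr.
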
// parseAddress parses the given address string based on the address type
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
// DNS, and opaque addresses.
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
        switch addrType {
        case addressTypeIPv4:
                tcp, err := net.ResolveTCPAddr("tcp4", address)
                if err != nil {
                        return nil, err
                }

                tcp.IP = tcp.IP.To4()

                return tcp, nil

        case addressTypeIPv6:
                tcp, err := net.ResolveTCPAddr("tcp6", address)
                if err != nil {
                        return nil, err
                }

                return tcp, nil

        case addressTypeTorV3, addressTypeTorV2:
                service, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split tor "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &tor.OnionAddr{
                        OnionService: service,
                        Port:         port,
                }, nil

        case addressTypeDNS:
                hostname, portStr, err := net.SplitHostPort(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to split DNS "+
                                "address: %v", address)
                }

                port, err := strconv.Atoi(portStr)
                if err != nil {
                        return nil, err
                }

                return &lnwire.DNSAddress{
                        Hostname: hostname,
                        Port:     uint16(port),
                }, nil

        case addressTypeOpaque:
                opaque, err := hex.DecodeString(address)
                if err != nil {
                        return nil, fmt.Errorf("unable to decode opaque "+
                                "address: %v", address)
                }

                return &lnwire.OpaqueAddrs{
                        Payload: opaque,
                }, nil

        default:
                return nil, fmt.Errorf("unknown address type: %v", addrType)
        }
}

// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
        // features is a map from a DB node ID to the feature bits for that
        // node.
        features map[int64][]int

        // addresses is a map from a DB node ID to the node's addresses.
        addresses map[int64][]nodeAddress

        // extraFields is a map from a DB node ID to the extra signed fields
        // for that node.
        extraFields map[int64]map[uint64][]byte
}

// nodeAddress holds the address type, position and address string for a
// node. This is used to batch the fetching of node addresses.
type nodeAddress struct {
        addrType dbAddressType
        position int32
        address  string
}

// batchLoadNodeData loads all related data for a batch of node IDs using the
// provided SQLQueries interface. It returns a batchNodeData instance containing
// the node features, addresses and extra signed fields.
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {

        // Batch load the node features.
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "features: %w", err)
        }

        // Batch load the node addresses.
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node "+
                        "addresses: %w", err)
        }

        // Batch load the node extra signed fields.
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node extra "+
                        "signed fields: %w", err)
        }

        return &batchNodeData{
                features:    features,
                addresses:   addrs,
                extraFields: extraTypes,
        }, nil
}

// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
// using the ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
func batchLoadNodeFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
                        error) {

                        return db.GetNodeFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
                        features[feature.NodeID] = append(
                                features[feature.NodeID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

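// Editor note with an illustrative usage sketch (not part of the upstream
// file): the batch helpers in this section all share the same shape. Results
// are accumulated in a map captured by the callback closure while
// sqldb.ExecuteBatchQuery drives the batched query and invokes the callback
// once per returned row.
//
//	features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
//	if err != nil {
//		// Handle the error.
//	}
//	// features is keyed by DB node ID; each value holds that node's raw
//	// feature bits.
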
// batchLoadNodeAddressesHelper loads node addresses using the
// ExecuteBatchQuery wrapper around the GetNodeAddressesBatch query. It
// returns a map from node ID to a slice of nodeAddress structs.
func batchLoadNodeAddressesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64][]nodeAddress, error) {

        addrs := make(map[int64][]nodeAddress)

        return addrs, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
                        error) {

                        return db.GetNodeAddressesBatch(ctx, ids)
                },
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
                        addrs[addr.NodeID] = append(
                                addrs[addr.NodeID], nodeAddress{
                                        addrType: dbAddressType(addr.Type),
                                        position: addr.Position,
                                        address:  addr.Address,
                                },
                        )

                        return nil
                },
        )
}

// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
// node IDs using the ExecuteBatchQuery wrapper around the
// GetNodeExtraTypesBatch query.
func batchLoadNodeExtraTypesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {

        extraFields := make(map[int64]map[uint64][]byte)

        callback := func(ctx context.Context,
                field sqlc.GraphNodeExtraType) error {

                if extraFields[field.NodeID] == nil {
                        extraFields[field.NodeID] = make(map[uint64][]byte)
                }
                extraFields[field.NodeID][uint64(field.Type)] = field.Value

                return nil
        }

        return extraFields, sqldb.ExecuteBatchQuery(
                ctx, cfg, nodeIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GraphNodeExtraType, error) {

                        return db.GetNodeExtraTypesBatch(ctx, ids)
                },
                callback,
        )
}

// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
// from the provided sqlc.GraphChannelPolicy records and the
// provided batchChannelData.
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
        channelID uint64, node1, node2 route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
        *models.ChannelEdgePolicy, error) {

        pol1, err := buildChanPolicyWithBatchData(
                dbPol1, channelID, node2, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
        }

        pol2, err := buildChanPolicyWithBatchData(
                dbPol2, channelID, node1, batchData,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
        }

        return pol1, pol2, nil
}

// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
        channelID uint64, toNode route.Vertex,
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {

        if dbPol == nil {
                return nil, nil
        }

        var dbPol1Extras map[uint64][]byte
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
                dbPol1Extras = extras
        } else {
                dbPol1Extras = make(map[uint64][]byte)
        }

        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
}

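// Editor note with an illustrative sketch (not part of the upstream file):
// buildChanPoliciesWithBatchData above pairs dbPol1 with node2 and dbPol2
// with node1 when deriving each policy's toNode vertex, mirroring how the
// callers below invoke it.
//
//	p1, p2, err := buildChanPoliciesWithBatchData(
//		dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
//	)
//	if err != nil {
//		// Handle the error.
//	}
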
// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
        // chanfeatures is a map from DB channel ID to a slice of feature bits.
        chanfeatures map[int64][]int

        // chanExtraTypes is a map from DB channel ID to a map of TLV type to
        // extra signed field bytes.
        chanExtraTypes map[int64]map[uint64][]byte

        // policyExtras is a map from DB channel policy ID to a map of TLV type
        // to extra signed field bytes.
        policyExtras map[int64]map[uint64][]byte
}

// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, channelIDs []int64,
        policyIDs []int64) (*batchChannelData, error) {

        batchData := &batchChannelData{
                chanfeatures:   make(map[int64][]int),
                chanExtraTypes: make(map[int64]map[uint64][]byte),
                policyExtras:   make(map[int64]map[uint64][]byte),
        }

        // Batch load channel features and extras
        var err error
        if len(channelIDs) > 0 {
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel features: %w", err)
                }

                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
                        ctx, cfg, db, channelIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "channel extras: %w", err)
                }
        }

        if len(policyIDs) > 0 {
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
                        ctx, cfg, db, policyIDs,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to batch load "+
                                "policy extras: %w", err)
                }
                batchData.policyExtras = policyExtras
        }

        return batchData, nil
}

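// Illustrative usage sketch (editor addition, not part of the upstream file):
// callers first collect the DB channel IDs and policy IDs for a page of rows
// and then issue a single batched load, as batchBuildChannelEdges does below.
//
//	batchData, err := batchLoadChannelData(
//		ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
//	)
//	if err != nil {
//		// Handle the error.
//	}
//	// batchData.chanfeatures, chanExtraTypes and policyExtras are keyed
//	// by the corresponding DB IDs.
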
// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64][]int, error) {

        features := make(map[int64][]int)

        return features, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {

                        return db.GetChannelFeaturesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        feature sqlc.GraphChannelFeature) error {

                        features[feature.ChannelID] = append(
                                features[feature.ChannelID],
                                int(feature.FeatureBit),
                        )

                        return nil
                },
        )
}

// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using the ExecuteBatchQuery wrapper around the
// GetChannelExtrasBatch query. It returns a map from DB channel ID to a map
// of TLV type to extra signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        cb := func(ctx context.Context,
                extra sqlc.GraphChannelExtraType) error {

                if extras[extra.ChannelID] == nil {
                        extras[extra.ChannelID] = make(map[uint64][]byte)
                }
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value

                return nil
        }

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, channelIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context,
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {

                        return db.GetChannelExtrasBatch(ctx, ids)
                }, cb,
        )
}

// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using the ExecuteBatchQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
// a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
        cfg *sqldb.QueryConfig, db SQLQueries,
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {

        extras := make(map[int64]map[uint64][]byte)

        return extras, sqldb.ExecuteBatchQuery(
                ctx, cfg, policyIDs,
                func(id int64) int64 {
                        return id
                },
                func(ctx context.Context, ids []int64) (
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {

                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
                },
                func(ctx context.Context,
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {

                        if extras[row.PolicyID] == nil {
                                extras[row.PolicyID] = make(map[uint64][]byte)
                        }
                        extras[row.PolicyID][uint64(row.Type)] = row.Value

                        return nil
                },
        )
}

// forEachNodePaginated executes a paginated query to process each node in the
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
// and applies the provided processNode function to each node.
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
        db SQLQueries, protocol lnwire.GossipVersion,
        processNode func(context.Context, int64,
                *models.Node) error) error {

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.GraphNode, error) {

                return db.ListNodesPaginated(
                        ctx, sqlc.ListNodesPaginatedParams{
                                Version: int16(protocol),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(node sqlc.GraphNode) int64 {
                return node.ID
        }

        collectFunc := func(node sqlc.GraphNode) (int64, error) {
                return node.ID, nil
        }

        batchQueryFunc := func(ctx context.Context,
                nodeIDs []int64) (*batchNodeData, error) {

                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
        }

        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
                batchData *batchNodeData) error {

                node, err := buildNodeWithBatchData(dbNode, batchData)
                if err != nil {
                        return fmt.Errorf("unable to build "+
                                "node(id=%d): %w", dbNode.ID, err)
                }

                return processNode(ctx, dbNode.ID, node)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchQueryFunc, processItem,
        )
}

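// Illustrative usage sketch (editor addition, not part of the upstream file):
// counting all gossip v1 nodes without keeping them in memory. Here cfg is
// assumed to be the store's *SQLStoreConfig; the callback receives the DB
// node ID alongside the assembled models.Node.
//
//	var count int
//	err := forEachNodePaginated(
//		ctx, cfg.QueryCfg, db, lnwire.GossipVersion1,
//		func(_ context.Context, _ int64, _ *models.Node) error {
//			count++
//			return nil
//		},
//	)
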
// forEachChannelWithPolicies executes a paginated query to process each channel
// with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
                *models.ChannelEdgePolicy,
                *models.ChannelEdgePolicy) error) error {

        type channelBatchIDs struct {
                channelID int64
                policyIDs []int64
        }

        pageQueryFunc := func(ctx context.Context, lastID int64,
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
                error) {

                return db.ListChannelsWithPoliciesPaginated(
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
                                Version: int16(lnwire.GossipVersion1),
                                ID:      lastID,
                                Limit:   limit,
                        },
                )
        }

        extractPageCursor := func(
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

                return row.GraphChannel.ID
        }

        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
                channelBatchIDs, error) {

                ids := channelBatchIDs{
                        channelID: row.GraphChannel.ID,
                }

                // Extract policy IDs from the row.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return ids, err
                }

                if dbPol1 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
                }

                return ids, nil
        }

        batchDataFunc := func(ctx context.Context,
                allIDs []channelBatchIDs) (*batchChannelData, error) {

                // Separate channel IDs from policy IDs.
                var (
                        channelIDs = make([]int64, len(allIDs))
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
                )

                for i, ids := range allIDs {
                        channelIDs[i] = ids.channelID
                        policyIDs = append(policyIDs, ids.policyIDs...)
                }

                return batchLoadChannelData(
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
                )
        }

        processItem := func(ctx context.Context,
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
                batchData *batchChannelData) error {

                node1, node2, err := buildNodeVertices(
                        row.Node1Pubkey, row.Node2Pubkey,
                )
                if err != nil {
                        return err
                }

                edge, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.GraphChannel, node1, node2,
                        batchData,
                )
                if err != nil {
                        return fmt.Errorf("unable to build channel info: %w",
                                err)
                }

                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return err
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
                )
                if err != nil {
                        return err
                }

                return processChannel(edge, p1, p2)
        }

        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
                collectFunc, batchDataFunc, processItem,
        )
}

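// Illustrative usage sketch (editor addition, not part of the upstream file):
// walking every gossip v1 channel together with whichever of its two policies
// are stored. Either policy pointer may be nil when no policy row exists for
// that direction.
//
//	err := forEachChannelWithPolicies(ctx, db, cfg,
//		func(info *models.ChannelEdgeInfo, p1,
//			p2 *models.ChannelEdgePolicy) error {
//
//			// Inspect info, p1 and p2 here.
//			return nil
//		},
//	)
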
// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {

        node1, node2, err := buildNodeVertices(
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
        }

        edge, err := buildEdgeInfoWithBatchData(
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel info: %w", err)
        }

        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
        if err != nil {
                return nil, fmt.Errorf("unable to extract channel policies: %w",
                        err)
        }

        p1, p2, err := buildChanPoliciesWithBatchData(
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
                channelBatchData,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to build channel policies: %w",
                        err)
        }

        // Determine outgoing and incoming policy for this specific node.
        p1ToNode := channelRow.GraphChannel.NodeID2
        p2ToNode := channelRow.GraphChannel.NodeID1
        outPolicy, inPolicy := p1, p2
        if (p1 != nil && p1ToNode == nodeID) ||
                (p2 != nil && p2ToNode != nodeID) {

                outPolicy, inPolicy = p2, p1
        }

        // Build cached policy.
        var cachedInPolicy *models.CachedEdgePolicy
        if inPolicy != nil {
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
                cachedInPolicy.ToNodePubKey = toNodeCallback
                cachedInPolicy.ToNodeFeatures = features
        }

        // Extract inbound fee.
        var inboundFee lnwire.Fee
        if outPolicy != nil {
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
                        inboundFee = fee
                })
        }

        // Build directed channel.
        directedChannel := &DirectedChannel{
                ChannelID:    edge.ChannelID,
                IsNode1:      nodePub == edge.NodeKey1Bytes,
                OtherNode:    edge.NodeKey2Bytes,
                Capacity:     edge.Capacity,
                OutPolicySet: outPolicy != nil,
                InPolicy:     cachedInPolicy,
                InboundFee:   inboundFee,
        }

        if nodePub == edge.NodeKey2Bytes {
                directedChannel.OtherNode = edge.NodeKey1Bytes
        }

        return directedChannel, nil
}

// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
// provided rows. It uses batch loading for channels, policies, and nodes.
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {

        var (
                channelIDs = make([]int64, len(rows))
                policyIDs  = make([]int64, 0, len(rows)*2)
                nodeIDs    = make([]int64, 0, len(rows)*2)

                // nodeIDSet is used to ensure we only collect unique node IDs.
                nodeIDSet = make(map[int64]bool)

                // edges will hold the final channel edges built from the rows.
                edges = make([]ChannelEdge, 0, len(rows))
        )

        // Collect all IDs needed for batch loading.
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID

                // Collect policy IDs
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }
                if dbPol1 != nil {
                        policyIDs = append(policyIDs, dbPol1.ID)
                }
                if dbPol2 != nil {
                        policyIDs = append(policyIDs, dbPol2.ID)
                }

                var (
                        node1ID = row.Node1().ID
                        node2ID = row.Node2().ID
                )

                // Collect unique node IDs.
                if !nodeIDSet[node1ID] {
                        nodeIDs = append(nodeIDs, node1ID)
                        nodeIDSet[node1ID] = true
                }

                if !nodeIDSet[node2ID] {
                        nodeIDs = append(nodeIDs, node2ID)
                        nodeIDSet[node2ID] = true
                }
        }

        // Batch the data for all the channels and policies.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
        )
        if err != nil {
                return nil, fmt.Errorf("unable to batch load channel and "+
                        "policy data: %w", err)
        }

        // Batch the data for all the nodes.
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
        if err != nil {
                return nil, fmt.Errorf("unable to batch load node data: %w",
                        err)
        }

        // Build all channel edges using batch data.
        for _, row := range rows {
                // Build nodes using batch data.
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node1: %w", err)
                }

                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
                if err != nil {
                        return nil, fmt.Errorf("unable to build node2: %w", err)
                }

                // Build channel info using batch data.
                channel, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
                        node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "info: %w", err)
                }

                // Extract and build policies using batch data.
                dbPol1, dbPol2, err := extractChannelPolicies(row)
                if err != nil {
                        return nil, fmt.Errorf("unable to extract channel "+
                                "policies: %w", err)
                }

                p1, p2, err := buildChanPoliciesWithBatchData(
                        dbPol1, dbPol2, channel.ChannelID,
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
                )
                if err != nil {
                        return nil, fmt.Errorf("unable to build channel "+
                                "policies: %w", err)
                }

                edges = append(edges, ChannelEdge{
                        Info:    channel,
                        Policy1: p1,
                        Policy2: p2,
                        Node1:   node1,
                        Node2:   node2,
                })
        }

        return edges, nil
}

// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
// instances from the provided rows using batch loading for channel data.
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
        []*models.ChannelEdgeInfo, []int64, error) {

        if len(rows) == 0 {
                return nil, nil, nil
        }

        // Collect all the channel IDs needed for batch loading.
        channelIDs := make([]int64, len(rows))
        for i, row := range rows {
                channelIDs[i] = row.Channel().ID
        }

        // Batch load the channel data.
        channelBatchData, err := batchLoadChannelData(
                ctx, cfg.QueryCfg, db, channelIDs, nil,
        )
        if err != nil {
                return nil, nil, fmt.Errorf("unable to batch load channel "+
                        "data: %w", err)
        }

        // Build all channel edges using batch data.
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
        for _, row := range rows {
                node1, node2, err := buildNodeVertices(
                        row.Node1Pub(), row.Node2Pub(),
                )
                if err != nil {
                        return nil, nil, err
                }

                // Build channel info using batch data
                info, err := buildEdgeInfoWithBatchData(
                        cfg.ChainHash, row.Channel(), node1, node2,
                        channelBatchData,
                )
                if err != nil {
                        return nil, nil, err
                }

                edges = append(edges, info)
        }

        return edges, channelIDs, nil
}

// handleZombieMarking is a helper function that handles the logic of
// marking a channel as a zombie in the database. It takes into account whether
// we are in strict zombie pruning mode, and adjusts the node public keys
// accordingly based on the last update timestamps of the channel policies.
func handleZombieMarking(ctx context.Context, db SQLQueries,
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
        strictZombiePruning bool, scid uint64) error {

        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes

        if strictZombiePruning {
                var e1UpdateTime, e2UpdateTime *time.Time
                if row.Policy1LastUpdate.Valid {
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
                        e1UpdateTime = &e1Time
                }
                if row.Policy2LastUpdate.Valid {
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
                        e2UpdateTime = &e2Time
                }

                nodeKey1, nodeKey2 = makeZombiePubkeys(
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
                        e2UpdateTime,
                )
        }

        return db.UpsertZombieChannel(
                ctx, sqlc.UpsertZombieChannelParams{
                        Version:  int16(lnwire.GossipVersion1),
                        Scid:     channelIDToBytes(scid),
                        NodeKey1: nodeKey1[:],
                        NodeKey2: nodeKey2[:],
                },
        )
}