• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 19316446305

13 Nov 2025 12:34AM UTC coverage: 65.219% (+8.3%) from 56.89%
19316446305

push

github

web-flow
Merge pull request #10343 from lightningnetwork/0-21-0-staging

Merge branch `0-21-staging`

361 of 5339 new or added lines in 47 files covered. (6.76%)

34 existing lines in 8 files now uncovered.

137571 of 210938 relevant lines covered (65.22%)

20832.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
37
// execute queries against the SQL graph tables.
38
//
39
//nolint:ll,interfacebloat
40
type SQLQueries interface {
41
        /*
42
                Node queries.
43
        */
44
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
45
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
46
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
47
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
48
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
49
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
50
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
51
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
52
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
53
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
54
        DeleteNode(ctx context.Context, id int64) error
55

56
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
57
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
58
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
59
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
60

61
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
62
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
63
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
64
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
65

66
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
67
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
68
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
69
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
70
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
71

72
        /*
73
                Source node queries.
74
        */
75
        AddSourceNode(ctx context.Context, nodeID int64) error
76
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
77

78
        /*
79
                Channel queries.
80
        */
81
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
82
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
83
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
84
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
85
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
86
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
87
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
88
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
89
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
90
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
91
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
92
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
93
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
94
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
95
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
96
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
97
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
98
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
99
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
100
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
101
        DeleteChannels(ctx context.Context, ids []int64) error
102

103
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
104
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
105
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
106
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
107

108
        /*
109
                Channel Policy table queries.
110
        */
111
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
112
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
113
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
114

115
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
116
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
117
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
118

119
        /*
120
                Zombie index queries.
121
        */
122
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
123
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
124
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
125
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
126
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
127
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
128

129
        /*
130
                Prune log table queries.
131
        */
132
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
133
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
134
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
135
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
136
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
137

138
        /*
139
                Closed SCID table queries.
140
        */
141
        InsertClosedChannel(ctx context.Context, scid []byte) error
142
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
143
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
144

145
        /*
146
                Migration specific queries.
147

148
                NOTE: these should not be used in code other than migrations.
149
                Once sqldbv2 is in place, these can be removed from this struct
150
                as then migrations will have their own dedicated queries
151
                structs.
152
        */
153
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
154
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
155
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
156
}
157

158
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
159
// database operations.
160
type BatchedSQLQueries interface {
161
        SQLQueries
162
        sqldb.BatchedTx[SQLQueries]
163
}
164

165
// SQLStore is an implementation of the V1Store interface that uses a SQL
166
// database as the backend.
167
type SQLStore struct {
168
        cfg *SQLStoreConfig
169
        db  BatchedSQLQueries
170

171
        // cacheMu guards all caches (rejectCache and chanCache). If
172
        // this mutex will be acquired at the same time as the DB mutex then
173
        // the cacheMu MUST be acquired first to prevent deadlock.
174
        cacheMu     sync.RWMutex
175
        rejectCache *rejectCache
176
        chanCache   *channelCache
177

178
        chanScheduler batch.Scheduler[SQLQueries]
179
        nodeScheduler batch.Scheduler[SQLQueries]
180

181
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
182
        srcNodeMu sync.Mutex
183
}
184

185
// A compile-time assertion to ensure that SQLStore implements the V1Store
186
// interface.
187
var _ V1Store = (*SQLStore)(nil)
188

189
// SQLStoreConfig holds the configuration for the SQLStore.
190
type SQLStoreConfig struct {
191
        // ChainHash is the genesis hash for the chain that all the gossip
192
        // messages in this store are aimed at.
193
        ChainHash chainhash.Hash
194

195
        // QueryConfig holds configuration values for SQL queries.
196
        QueryCfg *sqldb.QueryConfig
197
}
198

199
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
200
// storage backend.
201
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
202
        options ...StoreOptionModifier) (*SQLStore, error) {
×
203

×
204
        opts := DefaultOptions()
×
205
        for _, o := range options {
×
206
                o(opts)
×
207
        }
×
208

209
        if opts.NoMigration {
×
210
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
211
                        "supported for SQL stores")
×
212
        }
×
213

214
        s := &SQLStore{
×
215
                cfg:         cfg,
×
216
                db:          db,
×
217
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
218
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
NEW
219
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
220
        }
×
221

×
222
        s.chanScheduler = batch.NewTimeScheduler(
×
223
                db, &s.cacheMu, opts.BatchCommitInterval,
×
224
        )
×
225
        s.nodeScheduler = batch.NewTimeScheduler(
×
226
                db, nil, opts.BatchCommitInterval,
×
227
        )
×
228

×
229
        return s, nil
×
230
}
231

232
// AddNode adds a vertex/node to the graph database. If the node is not
233
// in the database from before, this will add a new, unconnected one to the
234
// graph. If it is present from before, this will update that node's
235
// information.
236
//
237
// NOTE: part of the V1Store interface.
238
func (s *SQLStore) AddNode(ctx context.Context,
239
        node *models.Node, opts ...batch.SchedulerOption) error {
×
240

×
241
        r := &batch.Request[SQLQueries]{
×
242
                Opts: batch.NewSchedulerOptions(opts...),
×
243
                Do: func(queries SQLQueries) error {
×
244
                        _, err := upsertNode(ctx, queries, node)
×
245

×
246
                        // It is possible that two of the same node
×
247
                        // announcements are both being processed in the same
×
248
                        // batch. This may case the UpsertNode conflict to
×
249
                        // be hit since we require at the db layer that the
×
250
                        // new last_update is greater than the existing
×
251
                        // last_update. We need to gracefully handle this here.
×
252
                        if errors.Is(err, sql.ErrNoRows) {
×
253
                                return nil
×
254
                        }
×
255

256
                        return err
×
257
                },
258
        }
259

260
        return s.nodeScheduler.Execute(ctx, r)
×
261
}
262

263
// FetchNode attempts to look up a target node by its identity public
264
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
265
// returned.
266
//
267
// NOTE: part of the V1Store interface.
268
func (s *SQLStore) FetchNode(ctx context.Context,
269
        pubKey route.Vertex) (*models.Node, error) {
×
270

×
271
        var node *models.Node
×
272
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
273
                var err error
×
274
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
275

×
276
                return err
×
277
        }, sqldb.NoOpReset)
×
278
        if err != nil {
×
279
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
280
        }
×
281

282
        return node, nil
×
283
}
284

285
// HasNode determines if the graph has a vertex identified by the
286
// target node identity public key. If the node exists in the database, a
287
// timestamp of when the data for the node was lasted updated is returned along
288
// with a true boolean. Otherwise, an empty time.Time is returned with a false
289
// boolean.
290
//
291
// NOTE: part of the V1Store interface.
292
func (s *SQLStore) HasNode(ctx context.Context,
293
        pubKey [33]byte) (time.Time, bool, error) {
×
294

×
295
        var (
×
296
                exists     bool
×
297
                lastUpdate time.Time
×
298
        )
×
299
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
300
                dbNode, err := db.GetNodeByPubKey(
×
301
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
302
                                Version: int16(lnwire.GossipVersion1),
×
303
                                PubKey:  pubKey[:],
×
304
                        },
×
305
                )
×
306
                if errors.Is(err, sql.ErrNoRows) {
×
307
                        return nil
×
308
                } else if err != nil {
×
309
                        return fmt.Errorf("unable to fetch node: %w", err)
×
310
                }
×
311

312
                exists = true
×
313

×
314
                if dbNode.LastUpdate.Valid {
×
315
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
316
                }
×
317

318
                return nil
×
319
        }, sqldb.NoOpReset)
320
        if err != nil {
×
321
                return time.Time{}, false,
×
322
                        fmt.Errorf("unable to fetch node: %w", err)
×
323
        }
×
324

325
        return lastUpdate, exists, nil
×
326
}
327

328
// AddrsForNode returns all known addresses for the target node public key
329
// that the graph DB is aware of. The returned boolean indicates if the
330
// given node is unknown to the graph DB or not.
331
//
332
// NOTE: part of the V1Store interface.
333
func (s *SQLStore) AddrsForNode(ctx context.Context,
334
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
335

×
336
        var (
×
337
                addresses []net.Addr
×
338
                known     bool
×
339
        )
×
340
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
341
                // First, check if the node exists and get its DB ID if it
×
342
                // does.
×
343
                dbID, err := db.GetNodeIDByPubKey(
×
344
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
NEW
345
                                Version: int16(lnwire.GossipVersion1),
×
346
                                PubKey:  nodePub.SerializeCompressed(),
×
347
                        },
×
348
                )
×
349
                if errors.Is(err, sql.ErrNoRows) {
×
350
                        return nil
×
351
                }
×
352

353
                known = true
×
354

×
355
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
356
                if err != nil {
×
357
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
358
                                err)
×
359
                }
×
360

361
                return nil
×
362
        }, sqldb.NoOpReset)
363
        if err != nil {
×
364
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
365
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
366
        }
×
367

368
        return known, addresses, nil
×
369
}
370

371
// DeleteNode starts a new database transaction to remove a vertex/node
372
// from the database according to the node's public key.
373
//
374
// NOTE: part of the V1Store interface.
375
func (s *SQLStore) DeleteNode(ctx context.Context,
376
        pubKey route.Vertex) error {
×
377

×
378
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
379
                res, err := db.DeleteNodeByPubKey(
×
380
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
NEW
381
                                Version: int16(lnwire.GossipVersion1),
×
382
                                PubKey:  pubKey[:],
×
383
                        },
×
384
                )
×
385
                if err != nil {
×
386
                        return err
×
387
                }
×
388

389
                rows, err := res.RowsAffected()
×
390
                if err != nil {
×
391
                        return err
×
392
                }
×
393

394
                if rows == 0 {
×
395
                        return ErrGraphNodeNotFound
×
396
                } else if rows > 1 {
×
397
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
398
                }
×
399

400
                return err
×
401
        }, sqldb.NoOpReset)
402
        if err != nil {
×
403
                return fmt.Errorf("unable to delete node: %w", err)
×
404
        }
×
405

406
        return nil
×
407
}
408

409
// FetchNodeFeatures returns the features of the given node. If no features are
410
// known for the node, an empty feature vector is returned.
411
//
412
// NOTE: this is part of the graphdb.NodeTraverser interface.
413
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
414
        *lnwire.FeatureVector, error) {
×
415

×
416
        ctx := context.TODO()
×
417

×
418
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
419
}
×
420

421
// DisabledChannelIDs returns the channel ids of disabled channels.
422
// A channel is disabled when two of the associated ChanelEdgePolicies
423
// have their disabled bit on.
424
//
425
// NOTE: part of the V1Store interface.
426
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
427
        var (
×
428
                ctx     = context.TODO()
×
429
                chanIDs []uint64
×
430
        )
×
431
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
432
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
433
                if err != nil {
×
434
                        return fmt.Errorf("unable to fetch disabled "+
×
435
                                "channels: %w", err)
×
436
                }
×
437

438
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
439

×
440
                return nil
×
441
        }, sqldb.NoOpReset)
442
        if err != nil {
×
443
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
444
                        err)
×
445
        }
×
446

447
        return chanIDs, nil
×
448
}
449

450
// LookupAlias attempts to return the alias as advertised by the target node.
451
//
452
// NOTE: part of the V1Store interface.
453
func (s *SQLStore) LookupAlias(ctx context.Context,
454
        pub *btcec.PublicKey) (string, error) {
×
455

×
456
        var alias string
×
457
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
458
                dbNode, err := db.GetNodeByPubKey(
×
459
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
460
                                Version: int16(lnwire.GossipVersion1),
×
461
                                PubKey:  pub.SerializeCompressed(),
×
462
                        },
×
463
                )
×
464
                if errors.Is(err, sql.ErrNoRows) {
×
465
                        return ErrNodeAliasNotFound
×
466
                } else if err != nil {
×
467
                        return fmt.Errorf("unable to fetch node: %w", err)
×
468
                }
×
469

470
                if !dbNode.Alias.Valid {
×
471
                        return ErrNodeAliasNotFound
×
472
                }
×
473

474
                alias = dbNode.Alias.String
×
475

×
476
                return nil
×
477
        }, sqldb.NoOpReset)
478
        if err != nil {
×
479
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
480
        }
×
481

482
        return alias, nil
×
483
}
484

485
// SourceNode returns the source node of the graph. The source node is treated
486
// as the center node within a star-graph. This method may be used to kick off
487
// a path finding algorithm in order to explore the reachability of another
488
// node based off the source node.
489
//
490
// NOTE: part of the V1Store interface.
491
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
492
        error) {
×
493

×
494
        var node *models.Node
×
495
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
496
                _, nodePub, err := s.getSourceNode(
×
NEW
497
                        ctx, db, lnwire.GossipVersion1,
×
NEW
498
                )
×
499
                if err != nil {
×
500
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
501
                                err)
×
502
                }
×
503

504
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
505

×
506
                return err
×
507
        }, sqldb.NoOpReset)
508
        if err != nil {
×
509
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
510
        }
×
511

512
        return node, nil
×
513
}
514

515
// SetSourceNode sets the source node within the graph database. The source
516
// node is to be used as the center of a star-graph within path finding
517
// algorithms.
518
//
519
// NOTE: part of the V1Store interface.
520
func (s *SQLStore) SetSourceNode(ctx context.Context,
521
        node *models.Node) error {
×
522

×
523
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
524
                id, err := upsertNode(ctx, db, node)
×
525
                if err != nil {
×
526
                        return fmt.Errorf("unable to upsert source node: %w",
×
527
                                err)
×
528
                }
×
529

530
                // Make sure that if a source node for this version is already
531
                // set, then the ID is the same as the one we are about to set.
NEW
532
                dbSourceNodeID, _, err := s.getSourceNode(
×
NEW
533
                        ctx, db, lnwire.GossipVersion1,
×
NEW
534
                )
×
535
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
536
                        return fmt.Errorf("unable to fetch source node: %w",
×
537
                                err)
×
538
                } else if err == nil {
×
539
                        if dbSourceNodeID != id {
×
540
                                return fmt.Errorf("v1 source node already "+
×
541
                                        "set to a different node: %d vs %d",
×
542
                                        dbSourceNodeID, id)
×
543
                        }
×
544

545
                        return nil
×
546
                }
547

548
                return db.AddSourceNode(ctx, id)
×
549
        }, sqldb.NoOpReset)
550
}
551

552
// NodeUpdatesInHorizon returns all the known lightning node which have an
553
// update timestamp within the passed range. This method can be used by two
554
// nodes to quickly determine if they have the same set of up to date node
555
// announcements.
556
//
557
// NOTE: This is part of the V1Store interface.
558
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
NEW
559
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
560

×
561
        cfg := defaultIteratorConfig()
×
562
        for _, opt := range opts {
×
563
                opt(cfg)
×
564
        }
×
565

NEW
566
        return func(yield func(*models.Node, error) bool) {
×
567
                var (
×
568
                        ctx            = context.TODO()
×
569
                        lastUpdateTime sql.NullInt64
×
570
                        lastPubKey     = make([]byte, 33)
×
571
                        hasMore        = true
×
572
                )
×
573

×
574
                // Each iteration, we'll read a batch amount of nodes, yield
×
575
                // them, then decide is we have more or not.
×
576
                for hasMore {
×
NEW
577
                        var batch []*models.Node
×
578

×
579
                        //nolint:ll
×
580
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
581
                                //nolint:ll
×
582
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
583
                                        StartTime: sqldb.SQLInt64(
×
584
                                                startTime.Unix(),
×
585
                                        ),
×
586
                                        EndTime: sqldb.SQLInt64(
×
587
                                                endTime.Unix(),
×
588
                                        ),
×
589
                                        LastUpdate: lastUpdateTime,
×
590
                                        LastPubKey: lastPubKey,
×
591
                                        OnlyPublic: sql.NullBool{
×
592
                                                Bool:  cfg.iterPublicNodes,
×
593
                                                Valid: true,
×
594
                                        },
×
595
                                        MaxResults: sqldb.SQLInt32(
×
596
                                                cfg.nodeUpdateIterBatchSize,
×
597
                                        ),
×
598
                                }
×
599
                                rows, err := db.GetNodesByLastUpdateRange(
×
600
                                        ctx, params,
×
601
                                )
×
602
                                if err != nil {
×
603
                                        return err
×
604
                                }
×
605

606
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
607

×
608
                                err = forEachNodeInBatch(
×
609
                                        ctx, s.cfg.QueryCfg, db, rows,
×
610
                                        func(_ int64, node *models.Node) error {
×
NEW
611
                                                batch = append(batch, node)
×
612

×
613
                                                // Update pagination cursors
×
614
                                                // based on the last processed
×
615
                                                // node.
×
616
                                                lastUpdateTime = sql.NullInt64{
×
617
                                                        Int64: node.LastUpdate.
×
618
                                                                Unix(),
×
619
                                                        Valid: true,
×
620
                                                }
×
621
                                                lastPubKey = node.PubKeyBytes[:]
×
622

×
623
                                                return nil
×
624
                                        },
×
625
                                )
626
                                if err != nil {
×
627
                                        return fmt.Errorf("unable to build "+
×
628
                                                "nodes: %w", err)
×
629
                                }
×
630

631
                                return nil
×
632
                        }, func() {
×
NEW
633
                                batch = []*models.Node{}
×
634
                        })
×
635

636
                        if err != nil {
×
637
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
638
                                        "error: %v", err)
×
639

×
NEW
640
                                yield(&models.Node{}, err)
×
641

×
642
                                return
×
643
                        }
×
644

645
                        for _, node := range batch {
×
646
                                if !yield(node, nil) {
×
647
                                        return
×
648
                                }
×
649
                        }
650

651
                        // If the batch didn't yield anything, then we're done.
652
                        if len(batch) == 0 {
×
653
                                break
×
654
                        }
655
                }
656
        }
657
}
658

659
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
660
// undirected edge from the two target nodes are created. The information stored
661
// denotes the static attributes of the channel, such as the channelID, the keys
662
// involved in creation of the channel, and the set of features that the channel
663
// supports. The chanPoint and chanID are used to uniquely identify the edge
664
// globally within the database.
665
//
666
// NOTE: part of the V1Store interface.
667
func (s *SQLStore) AddChannelEdge(ctx context.Context,
668
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
669

×
670
        var alreadyExists bool
×
671
        r := &batch.Request[SQLQueries]{
×
672
                Opts: batch.NewSchedulerOptions(opts...),
×
673
                Reset: func() {
×
674
                        alreadyExists = false
×
675
                },
×
676
                Do: func(tx SQLQueries) error {
×
677
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
678

×
679
                        // Make sure that the channel doesn't already exist. We
×
680
                        // do this explicitly instead of relying on catching a
×
681
                        // unique constraint error because relying on SQL to
×
682
                        // throw that error would abort the entire batch of
×
683
                        // transactions.
×
684
                        _, err := tx.GetChannelBySCID(
×
685
                                ctx, sqlc.GetChannelBySCIDParams{
×
686
                                        Scid:    chanIDB,
×
NEW
687
                                        Version: int16(lnwire.GossipVersion1),
×
688
                                },
×
689
                        )
×
690
                        if err == nil {
×
691
                                alreadyExists = true
×
692
                                return nil
×
693
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
694
                                return fmt.Errorf("unable to fetch channel: %w",
×
695
                                        err)
×
696
                        }
×
697

698
                        return insertChannel(ctx, tx, edge)
×
699
                },
700
                OnCommit: func(err error) error {
×
701
                        switch {
×
702
                        case err != nil:
×
703
                                return err
×
704
                        case alreadyExists:
×
705
                                return ErrEdgeAlreadyExist
×
706
                        default:
×
707
                                s.rejectCache.remove(edge.ChannelID)
×
708
                                s.chanCache.remove(edge.ChannelID)
×
709
                                return nil
×
710
                        }
711
                },
712
        }
713

714
        return s.chanScheduler.Execute(ctx, r)
×
715
}
716

717
// HighestChanID returns the "highest" known channel ID in the channel graph.
718
// This represents the "newest" channel from the PoV of the chain. This method
719
// can be used by peers to quickly determine if their graphs are in sync.
720
//
721
// NOTE: This is part of the V1Store interface.
722
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
723
        var highestChanID uint64
×
724
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
725
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
726
                if errors.Is(err, sql.ErrNoRows) {
×
727
                        return nil
×
728
                } else if err != nil {
×
729
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
730
                                err)
×
731
                }
×
732

733
                highestChanID = byteOrder.Uint64(chanID)
×
734

×
735
                return nil
×
736
        }, sqldb.NoOpReset)
737
        if err != nil {
×
738
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
739
        }
×
740

741
        return highestChanID, nil
×
742
}
743

744
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
745
// within the database for the referenced channel. The `flags` attribute within
746
// the ChannelEdgePolicy determines which of the directed edges are being
747
// updated. If the flag is 1, then the first node's information is being
748
// updated, otherwise it's the second node's information. The node ordering is
749
// determined by the lexicographical ordering of the identity public keys of the
750
// nodes on either side of the channel.
751
//
752
// NOTE: part of the V1Store interface.
753
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
754
        edge *models.ChannelEdgePolicy,
755
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
756

×
757
        var (
×
758
                isUpdate1    bool
×
759
                edgeNotFound bool
×
760
                from, to     route.Vertex
×
761
        )
×
762

×
763
        r := &batch.Request[SQLQueries]{
×
764
                Opts: batch.NewSchedulerOptions(opts...),
×
765
                Reset: func() {
×
766
                        isUpdate1 = false
×
767
                        edgeNotFound = false
×
768
                },
×
769
                Do: func(tx SQLQueries) error {
×
770
                        var err error
×
771
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
772
                                ctx, tx, edge,
×
773
                        )
×
774
                        // It is possible that two of the same policy
×
775
                        // announcements are both being processed in the same
×
776
                        // batch. This may case the UpsertEdgePolicy conflict to
×
777
                        // be hit since we require at the db layer that the
×
778
                        // new last_update is greater than the existing
×
779
                        // last_update. We need to gracefully handle this here.
×
780
                        if errors.Is(err, sql.ErrNoRows) {
×
781
                                return nil
×
782
                        } else if err != nil {
×
783
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
784
                        }
×
785

786
                        // Silence ErrEdgeNotFound so that the batch can
787
                        // succeed, but propagate the error via local state.
788
                        if errors.Is(err, ErrEdgeNotFound) {
×
789
                                edgeNotFound = true
×
790
                                return nil
×
791
                        }
×
792

793
                        return err
×
794
                },
795
                OnCommit: func(err error) error {
×
796
                        switch {
×
797
                        case err != nil:
×
798
                                return err
×
799
                        case edgeNotFound:
×
800
                                return ErrEdgeNotFound
×
801
                        default:
×
802
                                s.updateEdgeCache(edge, isUpdate1)
×
803
                                return nil
×
804
                        }
805
                },
806
        }
807

808
        err := s.chanScheduler.Execute(ctx, r)
×
809

×
810
        return from, to, err
×
811
}
812

813
// updateEdgeCache updates our reject and channel caches with the new
814
// edge policy information.
815
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
816
        isUpdate1 bool) {
×
817

×
818
        // If an entry for this channel is found in reject cache, we'll modify
×
819
        // the entry with the updated timestamp for the direction that was just
×
820
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
821
        // during the next query for this edge.
×
822
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
823
                if isUpdate1 {
×
824
                        entry.upd1Time = e.LastUpdate.Unix()
×
825
                } else {
×
826
                        entry.upd2Time = e.LastUpdate.Unix()
×
827
                }
×
828
                s.rejectCache.insert(e.ChannelID, entry)
×
829
        }
830

831
        // If an entry for this channel is found in channel cache, we'll modify
832
        // the entry with the updated policy for the direction that was just
833
        // written. If the edge doesn't exist, we'll defer loading the info and
834
        // policies and lazily read from disk during the next query.
835
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
836
                if isUpdate1 {
×
837
                        channel.Policy1 = e
×
838
                } else {
×
839
                        channel.Policy2 = e
×
840
                }
×
841
                s.chanCache.insert(e.ChannelID, channel)
×
842
        }
843
}
844

845
// ForEachSourceNodeChannel iterates through all channels of the source node,
846
// executing the passed callback on each. The call-back is provided with the
847
// channel's outpoint, whether we have a policy for the channel and the channel
848
// peer's node information.
849
//
850
// NOTE: part of the V1Store interface.
851
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
852
        cb func(chanPoint wire.OutPoint, havePolicy bool,
853
                otherNode *models.Node) error, reset func()) error {
×
854

×
855
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
856
                nodeID, nodePub, err := s.getSourceNode(
×
NEW
857
                        ctx, db, lnwire.GossipVersion1,
×
NEW
858
                )
×
859
                if err != nil {
×
860
                        return fmt.Errorf("unable to fetch source node: %w",
×
861
                                err)
×
862
                }
×
863

864
                return forEachNodeChannel(
×
865
                        ctx, db, s.cfg, nodeID,
×
866
                        func(info *models.ChannelEdgeInfo,
×
867
                                outPolicy *models.ChannelEdgePolicy,
×
868
                                _ *models.ChannelEdgePolicy) error {
×
869

×
870
                                // Fetch the other node.
×
871
                                var (
×
872
                                        otherNodePub [33]byte
×
873
                                        node1        = info.NodeKey1Bytes
×
874
                                        node2        = info.NodeKey2Bytes
×
875
                                )
×
876
                                switch {
×
877
                                case bytes.Equal(node1[:], nodePub[:]):
×
878
                                        otherNodePub = node2
×
879
                                case bytes.Equal(node2[:], nodePub[:]):
×
880
                                        otherNodePub = node1
×
881
                                default:
×
882
                                        return fmt.Errorf("node not " +
×
883
                                                "participating in this channel")
×
884
                                }
885

886
                                _, otherNode, err := getNodeByPubKey(
×
887
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
888
                                )
×
889
                                if err != nil {
×
890
                                        return fmt.Errorf("unable to fetch "+
×
891
                                                "other node(%x): %w",
×
892
                                                otherNodePub, err)
×
893
                                }
×
894

895
                                return cb(
×
896
                                        info.ChannelPoint, outPolicy != nil,
×
897
                                        otherNode,
×
898
                                )
×
899
                        },
900
                )
901
        }, reset)
902
}
903

904
// ForEachNode iterates through all the stored vertices/nodes in the graph,
905
// executing the passed callback with each node encountered. If the callback
906
// returns an error, then the transaction is aborted and the iteration stops
907
// early.
908
//
909
// NOTE: part of the V1Store interface.
910
func (s *SQLStore) ForEachNode(ctx context.Context,
911
        cb func(node *models.Node) error, reset func()) error {
×
912

×
913
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
914
                return forEachNodePaginated(
×
915
                        ctx, s.cfg.QueryCfg, db,
×
NEW
916
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
917
                                node *models.Node) error {
×
918

×
919
                                return cb(node)
×
920
                        },
×
921
                )
922
        }, reset)
923
}
924

925
// ForEachNodeDirectedChannel iterates through all channels of a given node,
926
// executing the passed callback on the directed edge representing the channel
927
// and its incoming policy. If the callback returns an error, then the iteration
928
// is halted with the error propagated back up to the caller.
929
//
930
// Unknown policies are passed into the callback as nil values.
931
//
932
// NOTE: this is part of the graphdb.NodeTraverser interface.
933
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
934
        cb func(channel *DirectedChannel) error, reset func()) error {
×
935

×
936
        var ctx = context.TODO()
×
937

×
938
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
939
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
940
        }, reset)
×
941
}
942

943
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
944
// graph, executing the passed callback with each node encountered. If the
945
// callback returns an error, then the transaction is aborted and the iteration
946
// stops early.
947
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
948
        cb func(route.Vertex, *lnwire.FeatureVector) error,
949
        reset func()) error {
×
950

×
951
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
952
                return forEachNodeCacheable(
×
953
                        ctx, s.cfg.QueryCfg, db,
×
954
                        func(_ int64, nodePub route.Vertex,
×
955
                                features *lnwire.FeatureVector) error {
×
956

×
957
                                return cb(nodePub, features)
×
958
                        },
×
959
                )
960
        }, reset)
961
        if err != nil {
×
962
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
963
        }
×
964

965
        return nil
×
966
}
967

968
// ForEachNodeChannel iterates through all channels of the given node,
969
// executing the passed callback with an edge info structure and the policies
970
// of each end of the channel. The first edge policy is the outgoing edge *to*
971
// the connecting node, while the second is the incoming edge *from* the
972
// connecting node. If the callback returns an error, then the iteration is
973
// halted with the error propagated back up to the caller.
974
//
975
// Unknown policies are passed into the callback as nil values.
976
//
977
// NOTE: part of the V1Store interface.
978
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
979
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
980
                *models.ChannelEdgePolicy) error, reset func()) error {
×
981

×
982
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
983
                dbNode, err := db.GetNodeByPubKey(
×
984
                        ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
985
                                Version: int16(lnwire.GossipVersion1),
×
986
                                PubKey:  nodePub[:],
×
987
                        },
×
988
                )
×
989
                if errors.Is(err, sql.ErrNoRows) {
×
990
                        return nil
×
991
                } else if err != nil {
×
992
                        return fmt.Errorf("unable to fetch node: %w", err)
×
993
                }
×
994

995
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
996
        }, reset)
997
}
998

999
// extractMaxUpdateTime returns the maximum of the two policy update times.
1000
// This is used for pagination cursor tracking.
1001
func extractMaxUpdateTime(
1002
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
1003

×
1004
        switch {
×
1005
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
1006
                return max(row.Policy1LastUpdate.Int64,
×
1007
                        row.Policy2LastUpdate.Int64)
×
1008
        case row.Policy1LastUpdate.Valid:
×
1009
                return row.Policy1LastUpdate.Int64
×
1010
        case row.Policy2LastUpdate.Valid:
×
1011
                return row.Policy2LastUpdate.Int64
×
1012
        default:
×
1013
                return 0
×
1014
        }
1015
}
1016

1017
// buildChannelFromRow constructs a ChannelEdge from a database row.
1018
// This includes building the nodes, channel info, and policies.
1019
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1020
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
1021

×
1022
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
1023
        if err != nil {
×
1024
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
1025
                        err)
×
1026
        }
×
1027

1028
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1029
        if err != nil {
×
1030
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1031
                        err)
×
1032
        }
×
1033

1034
        channel, err := getAndBuildEdgeInfo(
×
1035
                ctx, s.cfg, db,
×
1036
                row.GraphChannel, node1.PubKeyBytes,
×
1037
                node2.PubKeyBytes,
×
1038
        )
×
1039
        if err != nil {
×
1040
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1041
                        "channel info: %w", err)
×
1042
        }
×
1043

1044
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1045
        if err != nil {
×
1046
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1047
                        "channel policies: %w", err)
×
1048
        }
×
1049

1050
        p1, p2, err := getAndBuildChanPolicies(
×
1051
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1052
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1053
        )
×
1054
        if err != nil {
×
1055
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1056
                        "channel policies: %w", err)
×
1057
        }
×
1058

1059
        return ChannelEdge{
×
1060
                Info:    channel,
×
1061
                Policy1: p1,
×
1062
                Policy2: p2,
×
1063
                Node1:   node1,
×
1064
                Node2:   node2,
×
1065
        }, nil
×
1066
}
1067

1068
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1069
// This method acquires the cache lock only once for the entire batch.
1070
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1071
        if len(edgesToCache) == 0 {
×
1072
                return
×
1073
        }
×
1074

1075
        s.cacheMu.Lock()
×
1076
        defer s.cacheMu.Unlock()
×
1077

×
1078
        for chanID, edge := range edgesToCache {
×
1079
                s.chanCache.insert(chanID, edge)
×
1080
        }
×
1081
}
1082

1083
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1084
// one edge that has an update timestamp within the specified horizon.
1085
//
1086
// Iterator Lifecycle:
1087
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1088
// 2. Query batch of channels with policies in time range
1089
// 3. For each channel: check if seen, check cache, or build from DB
1090
// 4. Yield channels to caller
1091
// 5. Update cache after successful batch
1092
// 6. Repeat with updated pagination cursor until no more results
1093
//
1094
// NOTE: This is part of the V1Store interface.
1095
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
1096
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
1097

×
1098
        // Apply options.
×
1099
        cfg := defaultIteratorConfig()
×
1100
        for _, opt := range opts {
×
1101
                opt(cfg)
×
1102
        }
×
1103

1104
        return func(yield func(ChannelEdge, error) bool) {
×
1105
                var (
×
1106
                        ctx            = context.TODO()
×
1107
                        edgesSeen      = make(map[uint64]struct{})
×
1108
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1109
                        hits           int
×
1110
                        total          int
×
1111
                        lastUpdateTime sql.NullInt64
×
1112
                        lastID         sql.NullInt64
×
1113
                        hasMore        = true
×
1114
                )
×
1115

×
1116
                // Each iteration, we'll read a batch amount of channel updates
×
1117
                // (consulting the cache along the way), yield them, then loop
×
1118
                // back to decide if we have any more updates to read out.
×
1119
                for hasMore {
×
1120
                        var batch []ChannelEdge
×
1121

×
1122
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1123
                                func(db SQLQueries) error {
×
1124
                                        //nolint:ll
×
1125
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
NEW
1126
                                                Version: int16(lnwire.GossipVersion1),
×
1127
                                                StartTime: sqldb.SQLInt64(
×
1128
                                                        startTime.Unix(),
×
1129
                                                ),
×
1130
                                                EndTime: sqldb.SQLInt64(
×
1131
                                                        endTime.Unix(),
×
1132
                                                ),
×
1133
                                                LastUpdateTime: lastUpdateTime,
×
1134
                                                LastID:         lastID,
×
1135
                                                MaxResults: sql.NullInt32{
×
1136
                                                        Int32: int32(
×
1137
                                                                cfg.chanUpdateIterBatchSize,
×
1138
                                                        ),
×
1139
                                                        Valid: true,
×
1140
                                                },
×
1141
                                        }
×
1142
                                        //nolint:ll
×
1143
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1144
                                                ctx, params,
×
1145
                                        )
×
1146
                                        if err != nil {
×
1147
                                                return err
×
1148
                                        }
×
1149

1150
                                        //nolint:ll
1151
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1152

×
1153
                                        //nolint:ll
×
1154
                                        for _, row := range rows {
×
1155
                                                lastUpdateTime = sql.NullInt64{
×
1156
                                                        Int64: extractMaxUpdateTime(row),
×
1157
                                                        Valid: true,
×
1158
                                                }
×
1159
                                                lastID = sql.NullInt64{
×
1160
                                                        Int64: row.GraphChannel.ID,
×
1161
                                                        Valid: true,
×
1162
                                                }
×
1163

×
1164
                                                // Skip if we've already
×
1165
                                                // processed this channel.
×
1166
                                                chanIDInt := byteOrder.Uint64(
×
1167
                                                        row.GraphChannel.Scid,
×
1168
                                                )
×
1169
                                                _, ok := edgesSeen[chanIDInt]
×
1170
                                                if ok {
×
1171
                                                        continue
×
1172
                                                }
1173

1174
                                                s.cacheMu.RLock()
×
1175
                                                channel, ok := s.chanCache.get(
×
1176
                                                        chanIDInt,
×
1177
                                                )
×
1178
                                                s.cacheMu.RUnlock()
×
1179
                                                if ok {
×
1180
                                                        hits++
×
1181
                                                        total++
×
1182
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1183
                                                        batch = append(batch, channel)
×
1184

×
1185
                                                        continue
×
1186
                                                }
1187

1188
                                                chanEdge, err := s.buildChannelFromRow(
×
1189
                                                        ctx, db, row,
×
1190
                                                )
×
1191
                                                if err != nil {
×
1192
                                                        return err
×
1193
                                                }
×
1194

1195
                                                edgesSeen[chanIDInt] = struct{}{}
×
1196
                                                edgesToCache[chanIDInt] = chanEdge
×
1197

×
1198
                                                batch = append(batch, chanEdge)
×
1199

×
1200
                                                total++
×
1201
                                        }
1202

1203
                                        return nil
×
1204
                                }, func() {
×
1205
                                        batch = nil
×
1206
                                        edgesSeen = make(map[uint64]struct{})
×
1207
                                        edgesToCache = make(
×
1208
                                                map[uint64]ChannelEdge,
×
1209
                                        )
×
1210
                                })
×
1211

1212
                        if err != nil {
×
1213
                                log.Errorf("ChanUpdatesInHorizon "+
×
1214
                                        "batch error: %v", err)
×
1215

×
1216
                                yield(ChannelEdge{}, err)
×
1217

×
1218
                                return
×
1219
                        }
×
1220

1221
                        for _, edge := range batch {
×
1222
                                if !yield(edge, nil) {
×
1223
                                        return
×
1224
                                }
×
1225
                        }
1226

1227
                        // Update cache after successful batch yield, setting
1228
                        // the cache lock only once for the entire batch.
1229
                        s.updateChanCacheBatch(edgesToCache)
×
1230
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1231

×
1232
                        // If the batch didn't yield anything, then we're done.
×
1233
                        if len(batch) == 0 {
×
1234
                                break
×
1235
                        }
1236
                }
1237

1238
                if total > 0 {
×
1239
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1240
                                "%.2f (%d/%d)",
×
1241
                                float64(hits)*100/float64(total), hits, total)
×
1242
                } else {
×
1243
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
1244
                                "in horizon (%s, %s)", startTime, endTime)
×
1245
                }
×
1246
        }
1247
}
1248

1249
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1250
// data to the call-back. If withAddrs is true, then the call-back will also be
1251
// provided with the addresses associated with the node. The address retrieval
1252
// result in an additional round-trip to the database, so it should only be used
1253
// if the addresses are actually needed.
1254
//
1255
// NOTE: part of the V1Store interface.
1256
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1257
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
1258
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
1259

×
1260
        type nodeCachedBatchData struct {
×
1261
                features      map[int64][]int
×
1262
                addrs         map[int64][]nodeAddress
×
1263
                chanBatchData *batchChannelData
×
1264
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1265
        }
×
1266

×
1267
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1268
                // pageQueryFunc is used to query the next page of nodes.
×
1269
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1270
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1271

×
1272
                        return db.ListNodeIDsAndPubKeys(
×
1273
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
NEW
1274
                                        Version: int16(lnwire.GossipVersion1),
×
1275
                                        ID:      lastID,
×
1276
                                        Limit:   limit,
×
1277
                                },
×
1278
                        )
×
1279
                }
×
1280

1281
                // batchDataFunc is then used to batch load the data required
1282
                // for each page of nodes.
1283
                batchDataFunc := func(ctx context.Context,
×
1284
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1285

×
1286
                        // Batch load node features.
×
1287
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
1288
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1289
                        )
×
1290
                        if err != nil {
×
1291
                                return nil, fmt.Errorf("unable to batch load "+
×
1292
                                        "node features: %w", err)
×
1293
                        }
×
1294

1295
                        // Maybe fetch the node's addresses if requested.
1296
                        var nodeAddrs map[int64][]nodeAddress
×
1297
                        if withAddrs {
×
1298
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1299
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1300
                                )
×
1301
                                if err != nil {
×
1302
                                        return nil, fmt.Errorf("unable to "+
×
1303
                                                "batch load node "+
×
1304
                                                "addresses: %w", err)
×
1305
                                }
×
1306
                        }
1307

1308
                        // Batch load ALL unique channels for ALL nodes in this
1309
                        // page.
1310
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1311
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
NEW
1312
                                        Version:  int16(lnwire.GossipVersion1),
×
1313
                                        Node1Ids: nodeIDs,
×
1314
                                        Node2Ids: nodeIDs,
×
1315
                                },
×
1316
                        )
×
1317
                        if err != nil {
×
1318
                                return nil, fmt.Errorf("unable to batch "+
×
1319
                                        "fetch channels for nodes: %w", err)
×
1320
                        }
×
1321

1322
                        // Deduplicate channels and collect IDs.
1323
                        var (
×
1324
                                allChannelIDs []int64
×
1325
                                allPolicyIDs  []int64
×
1326
                        )
×
1327
                        uniqueChannels := make(
×
1328
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
1329
                        )
×
1330

×
1331
                        for _, channel := range allChannels {
×
1332
                                channelID := channel.GraphChannel.ID
×
1333

×
1334
                                // Only process each unique channel once.
×
1335
                                _, exists := uniqueChannels[channelID]
×
1336
                                if exists {
×
1337
                                        continue
×
1338
                                }
1339

1340
                                uniqueChannels[channelID] = channel
×
1341
                                allChannelIDs = append(allChannelIDs, channelID)
×
1342

×
1343
                                if channel.Policy1ID.Valid {
×
1344
                                        allPolicyIDs = append(
×
1345
                                                allPolicyIDs,
×
1346
                                                channel.Policy1ID.Int64,
×
1347
                                        )
×
1348
                                }
×
1349
                                if channel.Policy2ID.Valid {
×
1350
                                        allPolicyIDs = append(
×
1351
                                                allPolicyIDs,
×
1352
                                                channel.Policy2ID.Int64,
×
1353
                                        )
×
1354
                                }
×
1355
                        }
1356

1357
                        // Batch load channel data for all unique channels.
1358
                        channelBatchData, err := batchLoadChannelData(
×
1359
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1360
                                allPolicyIDs,
×
1361
                        )
×
1362
                        if err != nil {
×
1363
                                return nil, fmt.Errorf("unable to batch "+
×
1364
                                        "load channel data: %w", err)
×
1365
                        }
×
1366

1367
                        // Create map of node ID to channels that involve this
1368
                        // node.
1369
                        nodeIDSet := make(map[int64]bool)
×
1370
                        for _, nodeID := range nodeIDs {
×
1371
                                nodeIDSet[nodeID] = true
×
1372
                        }
×
1373

1374
                        nodeChannelMap := make(
×
1375
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1376
                        )
×
1377
                        for _, channel := range uniqueChannels {
×
1378
                                // Add channel to both nodes if they're in our
×
1379
                                // current page.
×
1380
                                node1 := channel.GraphChannel.NodeID1
×
1381
                                if nodeIDSet[node1] {
×
1382
                                        nodeChannelMap[node1] = append(
×
1383
                                                nodeChannelMap[node1], channel,
×
1384
                                        )
×
1385
                                }
×
1386
                                node2 := channel.GraphChannel.NodeID2
×
1387
                                if nodeIDSet[node2] {
×
1388
                                        nodeChannelMap[node2] = append(
×
1389
                                                nodeChannelMap[node2], channel,
×
1390
                                        )
×
1391
                                }
×
1392
                        }
1393

1394
                        return &nodeCachedBatchData{
×
1395
                                features:      nodeFeatures,
×
1396
                                addrs:         nodeAddrs,
×
1397
                                chanBatchData: channelBatchData,
×
1398
                                chanMap:       nodeChannelMap,
×
1399
                        }, nil
×
1400
                }
1401

1402
                // processItem is used to process each node in the current page.
1403
                processItem := func(ctx context.Context,
×
1404
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1405
                        batchData *nodeCachedBatchData) error {
×
1406

×
1407
                        // Build feature vector for this node.
×
1408
                        fv := lnwire.EmptyFeatureVector()
×
1409
                        features, exists := batchData.features[nodeData.ID]
×
1410
                        if exists {
×
1411
                                for _, bit := range features {
×
1412
                                        fv.Set(lnwire.FeatureBit(bit))
×
1413
                                }
×
1414
                        }
1415

1416
                        var nodePub route.Vertex
×
1417
                        copy(nodePub[:], nodeData.PubKey)
×
1418

×
1419
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1420

×
1421
                        toNodeCallback := func() route.Vertex {
×
1422
                                return nodePub
×
1423
                        }
×
1424

1425
                        // Build cached channels map for this node.
1426
                        channels := make(map[uint64]*DirectedChannel)
×
1427
                        for _, channelRow := range nodeChannels {
×
1428
                                directedChan, err := buildDirectedChannel(
×
1429
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1430
                                        channelRow, batchData.chanBatchData, fv,
×
1431
                                        toNodeCallback,
×
1432
                                )
×
1433
                                if err != nil {
×
1434
                                        return err
×
1435
                                }
×
1436

1437
                                channels[directedChan.ChannelID] = directedChan
×
1438
                        }
1439

1440
                        addrs, err := buildNodeAddresses(
×
1441
                                batchData.addrs[nodeData.ID],
×
1442
                        )
×
1443
                        if err != nil {
×
1444
                                return fmt.Errorf("unable to build node "+
×
1445
                                        "addresses: %w", err)
×
1446
                        }
×
1447

1448
                        return cb(ctx, nodePub, addrs, channels)
×
1449
                }
1450

1451
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1452
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1453
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
1454
                                return node.ID
×
1455
                        },
×
1456
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
1457
                                error) {
×
1458

×
1459
                                return node.ID, nil
×
1460
                        },
×
1461
                        batchDataFunc, processItem,
1462
                )
1463
        }, reset)
1464
}
1465

1466
// ForEachChannelCacheable iterates through all the channel edges stored
1467
// within the graph and invokes the passed callback for each edge. The
1468
// callback takes two edges as since this is a directed graph, both the
1469
// in/out edges are visited. If the callback returns an error, then the
1470
// transaction is aborted and the iteration stops early.
1471
//
1472
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1473
// pointer for that particular channel edge routing policy will be
1474
// passed into the callback.
1475
//
1476
// NOTE: this method is like ForEachChannel but fetches only the data
1477
// required for the graph cache.
1478
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1479
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
1480
        reset func()) error {
×
1481

×
1482
        ctx := context.TODO()
×
1483

×
1484
        handleChannel := func(_ context.Context,
×
1485
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
1486

×
1487
                node1, node2, err := buildNodeVertices(
×
1488
                        row.Node1Pubkey, row.Node2Pubkey,
×
1489
                )
×
1490
                if err != nil {
×
1491
                        return err
×
1492
                }
×
1493

1494
                edge := buildCacheableChannelInfo(
×
1495
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1496
                )
×
1497

×
1498
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1499
                if err != nil {
×
1500
                        return err
×
1501
                }
×
1502

1503
                pol1, pol2, err := buildCachedChanPolicies(
×
1504
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1505
                )
×
1506
                if err != nil {
×
1507
                        return err
×
1508
                }
×
1509

1510
                return cb(edge, pol1, pol2)
×
1511
        }
1512

1513
        extractCursor := func(
×
1514
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1515

×
1516
                return row.ID
×
1517
        }
×
1518

1519
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1520
                //nolint:ll
×
1521
                queryFunc := func(ctx context.Context, lastID int64,
×
1522
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1523
                        error) {
×
1524

×
1525
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1526
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
NEW
1527
                                        Version: int16(lnwire.GossipVersion1),
×
1528
                                        ID:      lastID,
×
1529
                                        Limit:   limit,
×
1530
                                },
×
1531
                        )
×
1532
                }
×
1533

1534
                return sqldb.ExecutePaginatedQuery(
×
1535
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1536
                        extractCursor, handleChannel,
×
1537
                )
×
1538
        }, reset)
1539
}
1540

1541
// ForEachChannel iterates through all the channel edges stored within the
1542
// graph and invokes the passed callback for each edge. The callback takes two
1543
// edges as since this is a directed graph, both the in/out edges are visited.
1544
// If the callback returns an error, then the transaction is aborted and the
1545
// iteration stops early.
1546
//
1547
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1548
// for that particular channel edge routing policy will be passed into the
1549
// callback.
1550
//
1551
// NOTE: part of the V1Store interface.
1552
func (s *SQLStore) ForEachChannel(ctx context.Context,
1553
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1554
                *models.ChannelEdgePolicy) error, reset func()) error {
×
1555

×
1556
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1557
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
1558
        }, reset)
×
1559
}
1560

1561
// FilterChannelRange returns the channel ID's of all known channels which were
1562
// mined in a block height within the passed range. The channel IDs are grouped
1563
// by their common block height. This method can be used to quickly share with a
1564
// peer the set of channels we know of within a particular range to catch them
1565
// up after a period of time offline. If withTimestamps is true then the
1566
// timestamp info of the latest received channel update messages of the channel
1567
// will be included in the response.
1568
//
1569
// NOTE: This is part of the V1Store interface.
1570
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
1571
        withTimestamps bool) ([]BlockChannelRange, error) {
×
1572

×
1573
        var (
×
1574
                ctx       = context.TODO()
×
1575
                startSCID = &lnwire.ShortChannelID{
×
1576
                        BlockHeight: startHeight,
×
1577
                }
×
1578
                endSCID = lnwire.ShortChannelID{
×
1579
                        BlockHeight: endHeight,
×
1580
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1581
                        TxPosition:  math.MaxUint16,
×
1582
                }
×
1583
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1584
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1585
        )
×
1586

×
1587
        // 1) get all channels where channelID is between start and end chan ID.
×
1588
        // 2) skip if not public (ie, no channel_proof)
×
1589
        // 3) collect that channel.
×
1590
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1591
        //    and add those timestamps to the collected channel.
×
1592
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1593
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1594
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1595
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1596
                                StartScid: chanIDStart,
×
1597
                                EndScid:   chanIDEnd,
×
1598
                        },
×
1599
                )
×
1600
                if err != nil {
×
1601
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1602
                                err)
×
1603
                }
×
1604

1605
                for _, dbChan := range dbChans {
×
1606
                        cid := lnwire.NewShortChanIDFromInt(
×
1607
                                byteOrder.Uint64(dbChan.Scid),
×
1608
                        )
×
1609
                        chanInfo := NewChannelUpdateInfo(
×
1610
                                cid, time.Time{}, time.Time{},
×
1611
                        )
×
1612

×
1613
                        if !withTimestamps {
×
1614
                                channelsPerBlock[cid.BlockHeight] = append(
×
1615
                                        channelsPerBlock[cid.BlockHeight],
×
1616
                                        chanInfo,
×
1617
                                )
×
1618

×
1619
                                continue
×
1620
                        }
1621

1622
                        //nolint:ll
1623
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1624
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1625
                                        Version:   int16(lnwire.GossipVersion1),
×
1626
                                        ChannelID: dbChan.ID,
×
1627
                                        NodeID:    dbChan.NodeID1,
×
1628
                                },
×
1629
                        )
×
1630
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1631
                                return fmt.Errorf("unable to fetch node1 "+
×
1632
                                        "policy: %w", err)
×
1633
                        } else if err == nil {
×
1634
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1635
                                        node1Policy.LastUpdate.Int64, 0,
×
1636
                                )
×
1637
                        }
×
1638

1639
                        //nolint:ll
1640
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1641
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
1642
                                        Version:   int16(lnwire.GossipVersion1),
×
1643
                                        ChannelID: dbChan.ID,
×
1644
                                        NodeID:    dbChan.NodeID2,
×
1645
                                },
×
1646
                        )
×
1647
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1648
                                return fmt.Errorf("unable to fetch node2 "+
×
1649
                                        "policy: %w", err)
×
1650
                        } else if err == nil {
×
1651
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1652
                                        node2Policy.LastUpdate.Int64, 0,
×
1653
                                )
×
1654
                        }
×
1655

1656
                        channelsPerBlock[cid.BlockHeight] = append(
×
1657
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1658
                        )
×
1659
                }
1660

1661
                return nil
×
1662
        }, func() {
×
1663
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1664
        })
×
1665
        if err != nil {
×
1666
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
1667
        }
×
1668

1669
        if len(channelsPerBlock) == 0 {
×
1670
                return nil, nil
×
1671
        }
×
1672

1673
        // Return the channel ranges in ascending block height order.
1674
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
1675
        slices.Sort(blocks)
×
1676

×
1677
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1678
                return BlockChannelRange{
×
1679
                        Height:   block,
×
1680
                        Channels: channelsPerBlock[block],
×
1681
                }
×
1682
        }), nil
×
1683
}
1684

1685
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1686
// zombie. This method is used on an ad-hoc basis, when channels need to be
1687
// marked as zombies outside the normal pruning cycle.
1688
//
1689
// NOTE: part of the V1Store interface.
1690
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
1691
        pubKey1, pubKey2 [33]byte) error {
×
1692

×
1693
        ctx := context.TODO()
×
1694

×
1695
        s.cacheMu.Lock()
×
1696
        defer s.cacheMu.Unlock()
×
1697

×
1698
        chanIDB := channelIDToBytes(chanID)
×
1699

×
1700
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1701
                return db.UpsertZombieChannel(
×
1702
                        ctx, sqlc.UpsertZombieChannelParams{
×
NEW
1703
                                Version:  int16(lnwire.GossipVersion1),
×
1704
                                Scid:     chanIDB,
×
1705
                                NodeKey1: pubKey1[:],
×
1706
                                NodeKey2: pubKey2[:],
×
1707
                        },
×
1708
                )
×
1709
        }, sqldb.NoOpReset)
×
1710
        if err != nil {
×
1711
                return fmt.Errorf("unable to upsert zombie channel "+
×
1712
                        "(channel_id=%d): %w", chanID, err)
×
1713
        }
×
1714

1715
        s.rejectCache.remove(chanID)
×
1716
        s.chanCache.remove(chanID)
×
1717

×
1718
        return nil
×
1719
}
1720

1721
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1722
//
1723
// NOTE: part of the V1Store interface.
1724
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1725
        s.cacheMu.Lock()
×
1726
        defer s.cacheMu.Unlock()
×
1727

×
1728
        var (
×
1729
                ctx     = context.TODO()
×
1730
                chanIDB = channelIDToBytes(chanID)
×
1731
        )
×
1732

×
1733
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1734
                res, err := db.DeleteZombieChannel(
×
1735
                        ctx, sqlc.DeleteZombieChannelParams{
×
1736
                                Scid:    chanIDB,
×
NEW
1737
                                Version: int16(lnwire.GossipVersion1),
×
1738
                        },
×
1739
                )
×
1740
                if err != nil {
×
1741
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1742
                                err)
×
1743
                }
×
1744

1745
                rows, err := res.RowsAffected()
×
1746
                if err != nil {
×
1747
                        return err
×
1748
                }
×
1749

1750
                if rows == 0 {
×
1751
                        return ErrZombieEdgeNotFound
×
1752
                } else if rows > 1 {
×
1753
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1754
                                "expected 1", rows)
×
1755
                }
×
1756

1757
                return nil
×
1758
        }, sqldb.NoOpReset)
1759
        if err != nil {
×
1760
                return fmt.Errorf("unable to mark edge live "+
×
1761
                        "(channel_id=%d): %w", chanID, err)
×
1762
        }
×
1763

1764
        s.rejectCache.remove(chanID)
×
1765
        s.chanCache.remove(chanID)
×
1766

×
1767
        return err
×
1768
}
1769

1770
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1771
// zombie, then the two node public keys corresponding to this edge are also
1772
// returned.
1773
//
1774
// NOTE: part of the V1Store interface.
1775
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
1776
        error) {
×
1777

×
1778
        var (
×
1779
                ctx              = context.TODO()
×
1780
                isZombie         bool
×
1781
                pubKey1, pubKey2 route.Vertex
×
1782
                chanIDB          = channelIDToBytes(chanID)
×
1783
        )
×
1784

×
1785
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1786
                zombie, err := db.GetZombieChannel(
×
1787
                        ctx, sqlc.GetZombieChannelParams{
×
1788
                                Scid:    chanIDB,
×
NEW
1789
                                Version: int16(lnwire.GossipVersion1),
×
1790
                        },
×
1791
                )
×
1792
                if errors.Is(err, sql.ErrNoRows) {
×
1793
                        return nil
×
1794
                }
×
1795
                if err != nil {
×
1796
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1797
                                err)
×
1798
                }
×
1799

1800
                copy(pubKey1[:], zombie.NodeKey1)
×
1801
                copy(pubKey2[:], zombie.NodeKey2)
×
1802
                isZombie = true
×
1803

×
1804
                return nil
×
1805
        }, sqldb.NoOpReset)
1806
        if err != nil {
×
1807
                return false, route.Vertex{}, route.Vertex{},
×
1808
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1809
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1810
        }
×
1811

1812
        return isZombie, pubKey1, pubKey2, nil
×
1813
}
1814

1815
// NumZombies returns the current number of zombie channels in the graph.
1816
//
1817
// NOTE: part of the V1Store interface.
1818
func (s *SQLStore) NumZombies() (uint64, error) {
×
1819
        var (
×
1820
                ctx        = context.TODO()
×
1821
                numZombies uint64
×
1822
        )
×
1823
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
NEW
1824
                count, err := db.CountZombieChannels(
×
NEW
1825
                        ctx, int16(lnwire.GossipVersion1),
×
NEW
1826
                )
×
1827
                if err != nil {
×
1828
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1829
                                err)
×
1830
                }
×
1831

1832
                numZombies = uint64(count)
×
1833

×
1834
                return nil
×
1835
        }, sqldb.NoOpReset)
1836
        if err != nil {
×
1837
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1838
        }
×
1839

1840
        return numZombies, nil
×
1841
}
1842

1843
// DeleteChannelEdges removes edges with the given channel IDs from the
1844
// database and marks them as zombies. This ensures that we're unable to re-add
1845
// it to our database once again. If an edge does not exist within the
1846
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1847
// true, then when we mark these edges as zombies, we'll set up the keys such
1848
// that we require the node that failed to send the fresh update to be the one
1849
// that resurrects the channel from its zombie state. The markZombie bool
1850
// denotes whether to mark the channel as a zombie.
1851
//
1852
// NOTE: part of the V1Store interface.
1853
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
1854
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
1855

×
1856
        s.cacheMu.Lock()
×
1857
        defer s.cacheMu.Unlock()
×
1858

×
1859
        // Keep track of which channels we end up finding so that we can
×
1860
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1861
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1862
        for _, chanID := range chanIDs {
×
1863
                chanLookup[chanID] = struct{}{}
×
1864
        }
×
1865

1866
        var (
×
1867
                ctx   = context.TODO()
×
1868
                edges []*models.ChannelEdgeInfo
×
1869
        )
×
1870
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1871
                // First, collect all channel rows.
×
1872
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1873
                chanCallBack := func(ctx context.Context,
×
1874
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1875

×
1876
                        // Deleting the entry from the map indicates that we
×
1877
                        // have found the channel.
×
1878
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1879
                        delete(chanLookup, scid)
×
1880

×
1881
                        channelRows = append(channelRows, row)
×
1882

×
1883
                        return nil
×
1884
                }
×
1885

1886
                err := s.forEachChanWithPoliciesInSCIDList(
×
1887
                        ctx, db, chanCallBack, chanIDs,
×
1888
                )
×
1889
                if err != nil {
×
1890
                        return err
×
1891
                }
×
1892

1893
                if len(chanLookup) > 0 {
×
1894
                        return ErrEdgeNotFound
×
1895
                }
×
1896

1897
                if len(channelRows) == 0 {
×
1898
                        return nil
×
1899
                }
×
1900

1901
                // Batch build all channel edges.
1902
                var chanIDsToDelete []int64
×
1903
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1904
                        ctx, s.cfg, db, channelRows,
×
1905
                )
×
1906
                if err != nil {
×
1907
                        return err
×
1908
                }
×
1909

1910
                if markZombie {
×
1911
                        for i, row := range channelRows {
×
1912
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1913

×
1914
                                err := handleZombieMarking(
×
1915
                                        ctx, db, row, edges[i],
×
1916
                                        strictZombiePruning, scid,
×
1917
                                )
×
1918
                                if err != nil {
×
1919
                                        return fmt.Errorf("unable to mark "+
×
1920
                                                "channel as zombie: %w", err)
×
1921
                                }
×
1922
                        }
1923
                }
1924

1925
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1926
        }, func() {
×
1927
                edges = nil
×
1928

×
1929
                // Re-fill the lookup map.
×
1930
                for _, chanID := range chanIDs {
×
1931
                        chanLookup[chanID] = struct{}{}
×
1932
                }
×
1933
        })
1934
        if err != nil {
×
1935
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1936
                        err)
×
1937
        }
×
1938

1939
        for _, chanID := range chanIDs {
×
1940
                s.rejectCache.remove(chanID)
×
1941
                s.chanCache.remove(chanID)
×
1942
        }
×
1943

1944
        return edges, nil
×
1945
}
1946

1947
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1948
// channel identified by the channel ID. If the channel can't be found, then
1949
// ErrEdgeNotFound is returned. A struct which houses the general information
1950
// for the channel itself is returned as well as two structs that contain the
1951
// routing policies for the channel in either direction.
1952
//
1953
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1954
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1955
// the ChannelEdgeInfo will only include the public keys of each node.
1956
//
1957
// NOTE: part of the V1Store interface.
1958
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1959
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1960
        *models.ChannelEdgePolicy, error) {
×
1961

×
1962
        var (
×
1963
                ctx              = context.TODO()
×
1964
                edge             *models.ChannelEdgeInfo
×
1965
                policy1, policy2 *models.ChannelEdgePolicy
×
1966
                chanIDB          = channelIDToBytes(chanID)
×
1967
        )
×
1968
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1969
                row, err := db.GetChannelBySCIDWithPolicies(
×
1970
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1971
                                Scid:    chanIDB,
×
NEW
1972
                                Version: int16(lnwire.GossipVersion1),
×
1973
                        },
×
1974
                )
×
1975
                if errors.Is(err, sql.ErrNoRows) {
×
1976
                        // First check if this edge is perhaps in the zombie
×
1977
                        // index.
×
1978
                        zombie, err := db.GetZombieChannel(
×
1979
                                ctx, sqlc.GetZombieChannelParams{
×
1980
                                        Scid:    chanIDB,
×
NEW
1981
                                        Version: int16(lnwire.GossipVersion1),
×
1982
                                },
×
1983
                        )
×
1984
                        if errors.Is(err, sql.ErrNoRows) {
×
1985
                                return ErrEdgeNotFound
×
1986
                        } else if err != nil {
×
1987
                                return fmt.Errorf("unable to check if "+
×
1988
                                        "channel is zombie: %w", err)
×
1989
                        }
×
1990

1991
                        // At this point, we know the channel is a zombie, so
1992
                        // we'll return an error indicating this, and we will
1993
                        // populate the edge info with the public keys of each
1994
                        // party as this is the only information we have about
1995
                        // it.
1996
                        edge = &models.ChannelEdgeInfo{}
×
1997
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
1998
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
1999

×
2000
                        return ErrZombieEdge
×
2001
                } else if err != nil {
×
2002
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2003
                }
×
2004

2005
                node1, node2, err := buildNodeVertices(
×
2006
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2007
                )
×
2008
                if err != nil {
×
2009
                        return err
×
2010
                }
×
2011

2012
                edge, err = getAndBuildEdgeInfo(
×
2013
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2014
                )
×
2015
                if err != nil {
×
2016
                        return fmt.Errorf("unable to build channel info: %w",
×
2017
                                err)
×
2018
                }
×
2019

2020
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2021
                if err != nil {
×
2022
                        return fmt.Errorf("unable to extract channel "+
×
2023
                                "policies: %w", err)
×
2024
                }
×
2025

2026
                policy1, policy2, err = getAndBuildChanPolicies(
×
2027
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2028
                        node1, node2,
×
2029
                )
×
2030
                if err != nil {
×
2031
                        return fmt.Errorf("unable to build channel "+
×
2032
                                "policies: %w", err)
×
2033
                }
×
2034

2035
                return nil
×
2036
        }, sqldb.NoOpReset)
2037
        if err != nil {
×
2038
                // If we are returning the ErrZombieEdge, then we also need to
×
2039
                // return the edge info as the method comment indicates that
×
2040
                // this will be populated when the edge is a zombie.
×
2041
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2042
                        err)
×
2043
        }
×
2044

2045
        return edge, policy1, policy2, nil
×
2046
}
2047

2048
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2049
// the channel identified by the funding outpoint. If the channel can't be
2050
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2051
// information for the channel itself is returned as well as two structs that
2052
// contain the routing policies for the channel in either direction.
2053
//
2054
// NOTE: part of the V1Store interface.
2055
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2056
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
2057
        *models.ChannelEdgePolicy, error) {
×
2058

×
2059
        var (
×
2060
                ctx              = context.TODO()
×
2061
                edge             *models.ChannelEdgeInfo
×
2062
                policy1, policy2 *models.ChannelEdgePolicy
×
2063
        )
×
2064
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2065
                row, err := db.GetChannelByOutpointWithPolicies(
×
2066
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2067
                                Outpoint: op.String(),
×
NEW
2068
                                Version:  int16(lnwire.GossipVersion1),
×
2069
                        },
×
2070
                )
×
2071
                if errors.Is(err, sql.ErrNoRows) {
×
2072
                        return ErrEdgeNotFound
×
2073
                } else if err != nil {
×
2074
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2075
                }
×
2076

2077
                node1, node2, err := buildNodeVertices(
×
2078
                        row.Node1Pubkey, row.Node2Pubkey,
×
2079
                )
×
2080
                if err != nil {
×
2081
                        return err
×
2082
                }
×
2083

2084
                edge, err = getAndBuildEdgeInfo(
×
2085
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2086
                )
×
2087
                if err != nil {
×
2088
                        return fmt.Errorf("unable to build channel info: %w",
×
2089
                                err)
×
2090
                }
×
2091

2092
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2093
                if err != nil {
×
2094
                        return fmt.Errorf("unable to extract channel "+
×
2095
                                "policies: %w", err)
×
2096
                }
×
2097

2098
                policy1, policy2, err = getAndBuildChanPolicies(
×
2099
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2100
                        node1, node2,
×
2101
                )
×
2102
                if err != nil {
×
2103
                        return fmt.Errorf("unable to build channel "+
×
2104
                                "policies: %w", err)
×
2105
                }
×
2106

2107
                return nil
×
2108
        }, sqldb.NoOpReset)
2109
        if err != nil {
×
2110
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2111
                        err)
×
2112
        }
×
2113

2114
        return edge, policy1, policy2, nil
×
2115
}
2116

2117
// HasChannelEdge returns true if the database knows of a channel edge with the
2118
// passed channel ID, and false otherwise. If an edge with that ID is found
2119
// within the graph, then two time stamps representing the last time the edge
2120
// was updated for both directed edges are returned along with the boolean. If
2121
// it is not found, then the zombie index is checked and its result is returned
2122
// as the second boolean.
2123
//
2124
// NOTE: part of the V1Store interface.
2125
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
2126
        bool, error) {
×
2127

×
2128
        ctx := context.TODO()
×
2129

×
2130
        var (
×
2131
                exists          bool
×
2132
                isZombie        bool
×
2133
                node1LastUpdate time.Time
×
2134
                node2LastUpdate time.Time
×
2135
        )
×
2136

×
2137
        // We'll query the cache with the shared lock held to allow multiple
×
2138
        // readers to access values in the cache concurrently if they exist.
×
2139
        s.cacheMu.RLock()
×
2140
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2141
                s.cacheMu.RUnlock()
×
2142
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2143
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2144
                exists, isZombie = entry.flags.unpack()
×
2145

×
2146
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2147
        }
×
2148
        s.cacheMu.RUnlock()
×
2149

×
2150
        s.cacheMu.Lock()
×
2151
        defer s.cacheMu.Unlock()
×
2152

×
2153
        // The item was not found with the shared lock, so we'll acquire the
×
2154
        // exclusive lock and check the cache again in case another method added
×
2155
        // the entry to the cache while no lock was held.
×
2156
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2157
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2158
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2159
                exists, isZombie = entry.flags.unpack()
×
2160

×
2161
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2162
        }
×
2163

2164
        chanIDB := channelIDToBytes(chanID)
×
2165
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2166
                channel, err := db.GetChannelBySCID(
×
2167
                        ctx, sqlc.GetChannelBySCIDParams{
×
2168
                                Scid:    chanIDB,
×
NEW
2169
                                Version: int16(lnwire.GossipVersion1),
×
2170
                        },
×
2171
                )
×
2172
                if errors.Is(err, sql.ErrNoRows) {
×
2173
                        // Check if it is a zombie channel.
×
2174
                        isZombie, err = db.IsZombieChannel(
×
2175
                                ctx, sqlc.IsZombieChannelParams{
×
2176
                                        Scid:    chanIDB,
×
NEW
2177
                                        Version: int16(lnwire.GossipVersion1),
×
2178
                                },
×
2179
                        )
×
2180
                        if err != nil {
×
2181
                                return fmt.Errorf("could not check if channel "+
×
2182
                                        "is zombie: %w", err)
×
2183
                        }
×
2184

2185
                        return nil
×
2186
                } else if err != nil {
×
2187
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2188
                }
×
2189

2190
                exists = true
×
2191

×
2192
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2193
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
2194
                                Version:   int16(lnwire.GossipVersion1),
×
2195
                                ChannelID: channel.ID,
×
2196
                                NodeID:    channel.NodeID1,
×
2197
                        },
×
2198
                )
×
2199
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2200
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2201
                                err)
×
2202
                } else if err == nil {
×
2203
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2204
                }
×
2205

2206
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2207
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
NEW
2208
                                Version:   int16(lnwire.GossipVersion1),
×
2209
                                ChannelID: channel.ID,
×
2210
                                NodeID:    channel.NodeID2,
×
2211
                        },
×
2212
                )
×
2213
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2214
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2215
                                err)
×
2216
                } else if err == nil {
×
2217
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2218
                }
×
2219

2220
                return nil
×
2221
        }, sqldb.NoOpReset)
2222
        if err != nil {
×
2223
                return time.Time{}, time.Time{}, false, false,
×
2224
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2225
        }
×
2226

2227
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2228
                upd1Time: node1LastUpdate.Unix(),
×
2229
                upd2Time: node2LastUpdate.Unix(),
×
2230
                flags:    packRejectFlags(exists, isZombie),
×
2231
        })
×
2232

×
2233
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2234
}
2235

2236
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2237
// passed channel point (outpoint). If the passed channel doesn't exist within
2238
// the database, then ErrEdgeNotFound is returned.
2239
//
2240
// NOTE: part of the V1Store interface.
2241
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
2242
        var (
×
2243
                ctx       = context.TODO()
×
2244
                channelID uint64
×
2245
        )
×
2246
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2247
                chanID, err := db.GetSCIDByOutpoint(
×
2248
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2249
                                Outpoint: chanPoint.String(),
×
NEW
2250
                                Version:  int16(lnwire.GossipVersion1),
×
2251
                        },
×
2252
                )
×
2253
                if errors.Is(err, sql.ErrNoRows) {
×
2254
                        return ErrEdgeNotFound
×
2255
                } else if err != nil {
×
2256
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2257
                                err)
×
2258
                }
×
2259

2260
                channelID = byteOrder.Uint64(chanID)
×
2261

×
2262
                return nil
×
2263
        }, sqldb.NoOpReset)
2264
        if err != nil {
×
2265
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2266
        }
×
2267

2268
        return channelID, nil
×
2269
}
2270

2271
// IsPublicNode is a helper method that determines whether the node with the
2272
// given public key is seen as a public node in the graph from the graph's
2273
// source node's point of view.
2274
//
2275
// NOTE: part of the V1Store interface.
2276
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
2277
        ctx := context.TODO()
×
2278

×
2279
        var isPublic bool
×
2280
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2281
                var err error
×
2282
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2283

×
2284
                return err
×
2285
        }, sqldb.NoOpReset)
×
2286
        if err != nil {
×
2287
                return false, fmt.Errorf("unable to check if node is "+
×
2288
                        "public: %w", err)
×
2289
        }
×
2290

2291
        return isPublic, nil
×
2292
}
2293

2294
// FetchChanInfos returns the set of channel edges that correspond to the passed
2295
// channel ID's. If an edge is the query is unknown to the database, it will
2296
// skipped and the result will contain only those edges that exist at the time
2297
// of the query. This can be used to respond to peer queries that are seeking to
2298
// fill in gaps in their view of the channel graph.
2299
//
2300
// NOTE: part of the V1Store interface.
2301
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
2302
        var (
×
2303
                ctx   = context.TODO()
×
2304
                edges = make(map[uint64]ChannelEdge)
×
2305
        )
×
2306
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2307
                // First, collect all channel rows.
×
2308
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2309
                chanCallBack := func(ctx context.Context,
×
2310
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2311

×
2312
                        channelRows = append(channelRows, row)
×
2313
                        return nil
×
2314
                }
×
2315

2316
                err := s.forEachChanWithPoliciesInSCIDList(
×
2317
                        ctx, db, chanCallBack, chanIDs,
×
2318
                )
×
2319
                if err != nil {
×
2320
                        return err
×
2321
                }
×
2322

2323
                if len(channelRows) == 0 {
×
2324
                        return nil
×
2325
                }
×
2326

2327
                // Batch build all channel edges.
2328
                chans, err := batchBuildChannelEdges(
×
2329
                        ctx, s.cfg, db, channelRows,
×
2330
                )
×
2331
                if err != nil {
×
2332
                        return fmt.Errorf("unable to build channel edges: %w",
×
2333
                                err)
×
2334
                }
×
2335

2336
                for _, c := range chans {
×
2337
                        edges[c.Info.ChannelID] = c
×
2338
                }
×
2339

2340
                return err
×
2341
        }, func() {
×
2342
                clear(edges)
×
2343
        })
×
2344
        if err != nil {
×
2345
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2346
        }
×
2347

2348
        res := make([]ChannelEdge, 0, len(edges))
×
2349
        for _, chanID := range chanIDs {
×
2350
                edge, ok := edges[chanID]
×
2351
                if !ok {
×
2352
                        continue
×
2353
                }
2354

2355
                res = append(res, edge)
×
2356
        }
2357

2358
        return res, nil
×
2359
}
2360

2361
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2362
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2363
// channels in a paginated manner.
2364
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2365
        db SQLQueries, cb func(ctx context.Context,
2366
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
2367
        chanIDs []uint64) error {
×
2368

×
2369
        queryWrapper := func(ctx context.Context,
×
2370
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
2371
                error) {
×
2372

×
2373
                return db.GetChannelsBySCIDWithPolicies(
×
2374
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
NEW
2375
                                Version: int16(lnwire.GossipVersion1),
×
2376
                                Scids:   scids,
×
2377
                        },
×
2378
                )
×
2379
        }
×
2380

2381
        return sqldb.ExecuteBatchQuery(
×
2382
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2383
                cb,
×
2384
        )
×
2385
}
2386

2387
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2388
// ID's that we don't know and are not known zombies of the passed set. In other
2389
// words, we perform a set difference of our set of chan ID's and the ones
2390
// passed in. This method can be used by callers to determine the set of
2391
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2392
// known zombies is also returned.
2393
//
2394
// NOTE: part of the V1Store interface.
2395
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
2396
        []ChannelUpdateInfo, error) {
×
2397

×
2398
        var (
×
2399
                ctx          = context.TODO()
×
2400
                newChanIDs   []uint64
×
2401
                knownZombies []ChannelUpdateInfo
×
2402
                infoLookup   = make(
×
2403
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2404
                )
×
2405
        )
×
2406

×
2407
        // We first build a lookup map of the channel ID's to the
×
2408
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2409
        // already know about.
×
2410
        for _, chanInfo := range chansInfo {
×
2411
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2412
        }
×
2413

2414
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2415
                // The call-back function deletes known channels from
×
2416
                // infoLookup, so that we can later check which channels are
×
2417
                // zombies by only looking at the remaining channels in the set.
×
2418
                cb := func(ctx context.Context,
×
2419
                        channel sqlc.GraphChannel) error {
×
2420

×
2421
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2422

×
2423
                        return nil
×
2424
                }
×
2425

2426
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2427
                if err != nil {
×
2428
                        return fmt.Errorf("unable to iterate through "+
×
2429
                                "channels: %w", err)
×
2430
                }
×
2431

2432
                // We want to ensure that we deal with the channels in the
2433
                // same order that they were passed in, so we iterate over the
2434
                // original chansInfo slice and then check if that channel is
2435
                // still in the infoLookup map.
2436
                for _, chanInfo := range chansInfo {
×
2437
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
2438
                        if _, ok := infoLookup[channelID]; !ok {
×
2439
                                continue
×
2440
                        }
2441

2442
                        isZombie, err := db.IsZombieChannel(
×
2443
                                ctx, sqlc.IsZombieChannelParams{
×
2444
                                        Scid:    channelIDToBytes(channelID),
×
NEW
2445
                                        Version: int16(lnwire.GossipVersion1),
×
2446
                                },
×
2447
                        )
×
2448
                        if err != nil {
×
2449
                                return fmt.Errorf("unable to fetch zombie "+
×
2450
                                        "channel: %w", err)
×
2451
                        }
×
2452

2453
                        if isZombie {
×
2454
                                knownZombies = append(knownZombies, chanInfo)
×
2455

×
2456
                                continue
×
2457
                        }
2458

2459
                        newChanIDs = append(newChanIDs, channelID)
×
2460
                }
2461

2462
                return nil
×
2463
        }, func() {
×
2464
                newChanIDs = nil
×
2465
                knownZombies = nil
×
2466
                // Rebuild the infoLookup map in case of a rollback.
×
2467
                for _, chanInfo := range chansInfo {
×
2468
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2469
                        infoLookup[scid] = chanInfo
×
2470
                }
×
2471
        })
2472
        if err != nil {
×
2473
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2474
        }
×
2475

2476
        return newChanIDs, knownZombies, nil
×
2477
}
2478

2479
// forEachChanInSCIDList is a helper method that executes a paged query
2480
// against the database to fetch all channels that match the passed
2481
// ChannelUpdateInfo slice. The callback function is called for each channel
2482
// that is found.
2483
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2484
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
2485
        chansInfo []ChannelUpdateInfo) error {
×
2486

×
2487
        queryWrapper := func(ctx context.Context,
×
2488
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
2489

×
2490
                return db.GetChannelsBySCIDs(
×
2491
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
NEW
2492
                                Version: int16(lnwire.GossipVersion1),
×
2493
                                Scids:   scids,
×
2494
                        },
×
2495
                )
×
2496
        }
×
2497

2498
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2499
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2500

×
2501
                return channelIDToBytes(channelID)
×
2502
        }
×
2503

2504
        return sqldb.ExecuteBatchQuery(
×
2505
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2506
                cb,
×
2507
        )
×
2508
}
2509

2510
// PruneGraphNodes is a garbage collection method which attempts to prune out
2511
// any nodes from the channel graph that are currently unconnected. This ensure
2512
// that we only maintain a graph of reachable nodes. In the event that a pruned
2513
// node gains more channels, it will be re-added back to the graph.
2514
//
2515
// NOTE: this prunes nodes across protocol versions. It will never prune the
2516
// source nodes.
2517
//
2518
// NOTE: part of the V1Store interface.
2519
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
2520
        var ctx = context.TODO()
×
2521

×
2522
        var prunedNodes []route.Vertex
×
2523
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2524
                var err error
×
2525
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2526

×
2527
                return err
×
2528
        }, func() {
×
2529
                prunedNodes = nil
×
2530
        })
×
2531
        if err != nil {
×
2532
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2533
        }
×
2534

2535
        return prunedNodes, nil
×
2536
}
2537

2538
// PruneGraph prunes newly closed channels from the channel graph in response
2539
// to a new block being solved on the network. Any transactions which spend the
2540
// funding output of any known channels within he graph will be deleted.
2541
// Additionally, the "prune tip", or the last block which has been used to
2542
// prune the graph is stored so callers can ensure the graph is fully in sync
2543
// with the current UTXO state. A slice of channels that have been closed by
2544
// the target block along with any pruned nodes are returned if the function
2545
// succeeds without error.
2546
//
2547
// NOTE: part of the V1Store interface.
2548
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2549
        blockHash *chainhash.Hash, blockHeight uint32) (
2550
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
2551

×
2552
        ctx := context.TODO()
×
2553

×
2554
        s.cacheMu.Lock()
×
2555
        defer s.cacheMu.Unlock()
×
2556

×
2557
        var (
×
2558
                closedChans []*models.ChannelEdgeInfo
×
2559
                prunedNodes []route.Vertex
×
2560
        )
×
2561
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2562
                // First, collect all channel rows that need to be pruned.
×
2563
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2564
                channelCallback := func(ctx context.Context,
×
2565
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2566

×
2567
                        channelRows = append(channelRows, row)
×
2568

×
2569
                        return nil
×
2570
                }
×
2571

2572
                err := s.forEachChanInOutpoints(
×
2573
                        ctx, db, spentOutputs, channelCallback,
×
2574
                )
×
2575
                if err != nil {
×
2576
                        return fmt.Errorf("unable to fetch channels by "+
×
2577
                                "outpoints: %w", err)
×
2578
                }
×
2579

2580
                if len(channelRows) == 0 {
×
2581
                        // There are no channels to prune. So we can exit early
×
2582
                        // after updating the prune log.
×
2583
                        err = db.UpsertPruneLogEntry(
×
2584
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2585
                                        BlockHash:   blockHash[:],
×
2586
                                        BlockHeight: int64(blockHeight),
×
2587
                                },
×
2588
                        )
×
2589
                        if err != nil {
×
2590
                                return fmt.Errorf("unable to insert prune log "+
×
2591
                                        "entry: %w", err)
×
2592
                        }
×
2593

2594
                        return nil
×
2595
                }
2596

2597
                // Batch build all channel edges for pruning.
2598
                var chansToDelete []int64
×
2599
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
2600
                        ctx, s.cfg, db, channelRows,
×
2601
                )
×
2602
                if err != nil {
×
2603
                        return err
×
2604
                }
×
2605

2606
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2607
                if err != nil {
×
2608
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2609
                }
×
2610

2611
                err = db.UpsertPruneLogEntry(
×
2612
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2613
                                BlockHash:   blockHash[:],
×
2614
                                BlockHeight: int64(blockHeight),
×
2615
                        },
×
2616
                )
×
2617
                if err != nil {
×
2618
                        return fmt.Errorf("unable to insert prune log "+
×
2619
                                "entry: %w", err)
×
2620
                }
×
2621

2622
                // Now that we've pruned some channels, we'll also prune any
2623
                // nodes that no longer have any channels.
2624
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2625
                if err != nil {
×
2626
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
2627
                                err)
×
2628
                }
×
2629

2630
                return nil
×
2631
        }, func() {
×
2632
                prunedNodes = nil
×
2633
                closedChans = nil
×
2634
        })
×
2635
        if err != nil {
×
2636
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2637
        }
×
2638

2639
        for _, channel := range closedChans {
×
2640
                s.rejectCache.remove(channel.ChannelID)
×
2641
                s.chanCache.remove(channel.ChannelID)
×
2642
        }
×
2643

2644
        return closedChans, prunedNodes, nil
×
2645
}
2646

2647
// forEachChanInOutpoints is a helper function that executes a paginated
2648
// query to fetch channels by their outpoints and applies the given call-back
2649
// to each.
2650
//
2651
// NOTE: this fetches channels for all protocol versions.
2652
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2653
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
2654
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
2655

×
2656
        // Create a wrapper that uses the transaction's db instance to execute
×
2657
        // the query.
×
2658
        queryWrapper := func(ctx context.Context,
×
2659
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2660
                error) {
×
2661

×
2662
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2663
        }
×
2664

2665
        // Define the conversion function from Outpoint to string.
2666
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2667
                return outpoint.String()
×
2668
        }
×
2669

2670
        return sqldb.ExecuteBatchQuery(
×
2671
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2672
                queryWrapper, cb,
×
2673
        )
×
2674
}
2675

2676
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2677
        dbIDs []int64) error {
×
2678

×
2679
        // Create a wrapper that uses the transaction's db instance to execute
×
2680
        // the query.
×
2681
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2682
                return nil, db.DeleteChannels(ctx, ids)
×
2683
        }
×
2684

2685
        idConverter := func(id int64) int64 {
×
2686
                return id
×
2687
        }
×
2688

2689
        return sqldb.ExecuteBatchQuery(
×
2690
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2691
                queryWrapper, func(ctx context.Context, _ any) error {
×
2692
                        return nil
×
2693
                },
×
2694
        )
2695
}
2696

2697
// ChannelView returns the verifiable edge information for each active channel
2698
// within the known channel graph. The set of UTXOs (along with their scripts)
2699
// returned are the ones that need to be watched on chain to detect channel
2700
// closes on the resident blockchain.
2701
//
2702
// NOTE: part of the V1Store interface.
2703
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
2704
        var (
×
2705
                ctx        = context.TODO()
×
2706
                edgePoints []EdgePoint
×
2707
        )
×
2708

×
2709
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2710
                handleChannel := func(_ context.Context,
×
2711
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2712

×
2713
                        pkScript, err := genMultiSigP2WSH(
×
2714
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2715
                        )
×
2716
                        if err != nil {
×
2717
                                return err
×
2718
                        }
×
2719

2720
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2721
                        if err != nil {
×
2722
                                return err
×
2723
                        }
×
2724

2725
                        edgePoints = append(edgePoints, EdgePoint{
×
2726
                                FundingPkScript: pkScript,
×
2727
                                OutPoint:        *op,
×
2728
                        })
×
2729

×
2730
                        return nil
×
2731
                }
2732

2733
                queryFunc := func(ctx context.Context, lastID int64,
×
2734
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2735

×
2736
                        return db.ListChannelsPaginated(
×
2737
                                ctx, sqlc.ListChannelsPaginatedParams{
×
NEW
2738
                                        Version: int16(lnwire.GossipVersion1),
×
2739
                                        ID:      lastID,
×
2740
                                        Limit:   limit,
×
2741
                                },
×
2742
                        )
×
2743
                }
×
2744

2745
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2746
                        return row.ID
×
2747
                }
×
2748

2749
                return sqldb.ExecutePaginatedQuery(
×
2750
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2751
                        extractCursor, handleChannel,
×
2752
                )
×
2753
        }, func() {
×
2754
                edgePoints = nil
×
2755
        })
×
2756
        if err != nil {
×
2757
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2758
        }
×
2759

2760
        return edgePoints, nil
×
2761
}
2762

2763
// PruneTip returns the block height and hash of the latest block that has been
2764
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2765
// to tell if the graph is currently in sync with the current best known UTXO
2766
// state.
2767
//
2768
// NOTE: part of the V1Store interface.
2769
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
2770
        var (
×
2771
                ctx       = context.TODO()
×
2772
                tipHash   chainhash.Hash
×
2773
                tipHeight uint32
×
2774
        )
×
2775
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2776
                pruneTip, err := db.GetPruneTip(ctx)
×
2777
                if errors.Is(err, sql.ErrNoRows) {
×
2778
                        return ErrGraphNeverPruned
×
2779
                } else if err != nil {
×
2780
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2781
                }
×
2782

2783
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2784
                tipHeight = uint32(pruneTip.BlockHeight)
×
2785

×
2786
                return nil
×
2787
        }, sqldb.NoOpReset)
2788
        if err != nil {
×
2789
                return nil, 0, err
×
2790
        }
×
2791

2792
        return &tipHash, tipHeight, nil
×
2793
}
2794

2795
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2796
//
2797
// NOTE: this prunes nodes across protocol versions. It will never prune the
2798
// source nodes.
2799
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
2800
        db SQLQueries) ([]route.Vertex, error) {
×
2801

×
2802
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
2803
        if err != nil {
×
2804
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2805
                        "nodes: %w", err)
×
2806
        }
×
2807

2808
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2809
        for i, nodeKey := range nodeKeys {
×
2810
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2811
                if err != nil {
×
2812
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2813
                                "from bytes: %w", err)
×
2814
                }
×
2815

2816
                prunedNodes[i] = pub
×
2817
        }
2818

2819
        return prunedNodes, nil
×
2820
}
2821

2822
// DisconnectBlockAtHeight is used to indicate that the block specified
2823
// by the passed height has been disconnected from the main chain. This
2824
// will "rewind" the graph back to the height below, deleting channels
2825
// that are no longer confirmed from the graph. The prune log will be
2826
// set to the last prune height valid for the remaining chain.
2827
// Channels that were removed from the graph resulting from the
2828
// disconnected block are returned.
2829
//
2830
// NOTE: part of the V1Store interface.
2831
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
2832
        []*models.ChannelEdgeInfo, error) {
×
2833

×
2834
        ctx := context.TODO()
×
2835

×
2836
        var (
×
2837
                // Every channel having a ShortChannelID starting at 'height'
×
2838
                // will no longer be confirmed.
×
2839
                startShortChanID = lnwire.ShortChannelID{
×
2840
                        BlockHeight: height,
×
2841
                }
×
2842

×
2843
                // Delete everything after this height from the db up until the
×
2844
                // SCID alias range.
×
2845
                endShortChanID = aliasmgr.StartingAlias
×
2846

×
2847
                removedChans []*models.ChannelEdgeInfo
×
2848

×
2849
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2850
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2851
        )
×
2852

×
2853
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2854
                rows, err := db.GetChannelsBySCIDRange(
×
2855
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2856
                                StartScid: chanIDStart,
×
2857
                                EndScid:   chanIDEnd,
×
2858
                        },
×
2859
                )
×
2860
                if err != nil {
×
2861
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2862
                }
×
2863

2864
                if len(rows) == 0 {
×
2865
                        // No channels to disconnect, but still clean up prune
×
2866
                        // log.
×
2867
                        return db.DeletePruneLogEntriesInRange(
×
2868
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2869
                                        StartHeight: int64(height),
×
2870
                                        EndHeight: int64(
×
2871
                                                endShortChanID.BlockHeight,
×
2872
                                        ),
×
2873
                                },
×
2874
                        )
×
2875
                }
×
2876

2877
                // Batch build all channel edges for disconnection.
2878
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2879
                        ctx, s.cfg, db, rows,
×
2880
                )
×
2881
                if err != nil {
×
2882
                        return err
×
2883
                }
×
2884

2885
                removedChans = channelEdges
×
2886

×
2887
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2888
                if err != nil {
×
2889
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2890
                }
×
2891

2892
                return db.DeletePruneLogEntriesInRange(
×
2893
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2894
                                StartHeight: int64(height),
×
2895
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
2896
                        },
×
2897
                )
×
2898
        }, func() {
×
2899
                removedChans = nil
×
2900
        })
×
2901
        if err != nil {
×
2902
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2903
                        "height: %w", err)
×
2904
        }
×
2905

2906
        for _, channel := range removedChans {
×
2907
                s.rejectCache.remove(channel.ChannelID)
×
2908
                s.chanCache.remove(channel.ChannelID)
×
2909
        }
×
2910

2911
        return removedChans, nil
×
2912
}
2913

2914
// AddEdgeProof sets the proof of an existing edge in the graph database.
2915
//
2916
// NOTE: part of the V1Store interface.
2917
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
2918
        proof *models.ChannelAuthProof) error {
×
2919

×
2920
        var (
×
2921
                ctx       = context.TODO()
×
2922
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2923
        )
×
2924

×
2925
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2926
                res, err := db.AddV1ChannelProof(
×
2927
                        ctx, sqlc.AddV1ChannelProofParams{
×
2928
                                Scid:              scidBytes,
×
2929
                                Node1Signature:    proof.NodeSig1Bytes,
×
2930
                                Node2Signature:    proof.NodeSig2Bytes,
×
2931
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2932
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2933
                        },
×
2934
                )
×
2935
                if err != nil {
×
2936
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2937
                }
×
2938

2939
                n, err := res.RowsAffected()
×
2940
                if err != nil {
×
2941
                        return err
×
2942
                }
×
2943

2944
                if n == 0 {
×
2945
                        return fmt.Errorf("no rows affected when adding edge "+
×
2946
                                "proof for SCID %v", scid)
×
2947
                } else if n > 1 {
×
2948
                        return fmt.Errorf("multiple rows affected when adding "+
×
2949
                                "edge proof for SCID %v: %d rows affected",
×
2950
                                scid, n)
×
2951
                }
×
2952

2953
                return nil
×
2954
        }, sqldb.NoOpReset)
2955
        if err != nil {
×
2956
                return fmt.Errorf("unable to add edge proof: %w", err)
×
2957
        }
×
2958

2959
        return nil
×
2960
}
2961

2962
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2963
// that we can ignore channel announcements that we know to be closed without
2964
// having to validate them and fetch a block.
2965
//
2966
// NOTE: part of the V1Store interface.
2967
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
2968
        var (
×
2969
                ctx     = context.TODO()
×
2970
                chanIDB = channelIDToBytes(scid.ToUint64())
×
2971
        )
×
2972

×
2973
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2974
                return db.InsertClosedChannel(ctx, chanIDB)
×
2975
        }, sqldb.NoOpReset)
×
2976
}
2977

2978
// IsClosedScid checks whether a channel identified by the passed in scid is
2979
// closed. This helps avoid having to perform expensive validation checks.
2980
//
2981
// NOTE: part of the V1Store interface.
2982
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
2983
        var (
×
2984
                ctx      = context.TODO()
×
2985
                isClosed bool
×
2986
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2987
        )
×
2988
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2989
                var err error
×
2990
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2991
                if err != nil {
×
2992
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2993
                                err)
×
2994
                }
×
2995

2996
                return nil
×
2997
        }, sqldb.NoOpReset)
2998
        if err != nil {
×
2999
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
3000
                        err)
×
3001
        }
×
3002

3003
        return isClosed, nil
×
3004
}
3005

3006
// GraphSession will provide the call-back with access to a NodeTraverser
3007
// instance which can be used to perform queries against the channel graph.
3008
//
3009
// NOTE: part of the V1Store interface.
3010
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
3011
        reset func()) error {
×
3012

×
3013
        var ctx = context.TODO()
×
3014

×
3015
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3016
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3017
        }, reset)
×
3018
}
3019

3020
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3021
// read only transaction for a consistent view of the graph.
3022
type sqlNodeTraverser struct {
3023
        db    SQLQueries
3024
        chain chainhash.Hash
3025
}
3026

3027
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3028
// NodeTraverser interface.
3029
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3030

3031
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3032
func newSQLNodeTraverser(db SQLQueries,
3033
        chain chainhash.Hash) *sqlNodeTraverser {
×
3034

×
3035
        return &sqlNodeTraverser{
×
3036
                db:    db,
×
3037
                chain: chain,
×
3038
        }
×
3039
}
×
3040

3041
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3042
// node.
3043
//
3044
// NOTE: Part of the NodeTraverser interface.
3045
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
3046
        cb func(channel *DirectedChannel) error, _ func()) error {
×
3047

×
3048
        ctx := context.TODO()
×
3049

×
3050
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3051
}
×
3052

3053
// FetchNodeFeatures returns the features of the given node. If the node is
3054
// unknown, assume no additional features are supported.
3055
//
3056
// NOTE: Part of the NodeTraverser interface.
3057
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
3058
        *lnwire.FeatureVector, error) {
×
3059

×
3060
        ctx := context.TODO()
×
3061

×
3062
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3063
}
×
3064

3065
// forEachNodeDirectedChannel iterates through all channels of a given
3066
// node, executing the passed callback on the directed edge representing the
3067
// channel and its incoming policy. If the node is not found, no error is
3068
// returned.
3069
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3070
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
3071

×
3072
        toNodeCallback := func() route.Vertex {
×
3073
                return nodePub
×
3074
        }
×
3075

3076
        dbID, err := db.GetNodeIDByPubKey(
×
3077
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
NEW
3078
                        Version: int16(lnwire.GossipVersion1),
×
3079
                        PubKey:  nodePub[:],
×
3080
                },
×
3081
        )
×
3082
        if errors.Is(err, sql.ErrNoRows) {
×
3083
                return nil
×
3084
        } else if err != nil {
×
3085
                return fmt.Errorf("unable to fetch node: %w", err)
×
3086
        }
×
3087

3088
        rows, err := db.ListChannelsByNodeID(
×
3089
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
3090
                        Version: int16(lnwire.GossipVersion1),
×
3091
                        NodeID1: dbID,
×
3092
                },
×
3093
        )
×
3094
        if err != nil {
×
3095
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3096
        }
×
3097

3098
        // Exit early if there are no channels for this node so we don't
3099
        // do the unnecessary feature fetching.
3100
        if len(rows) == 0 {
×
3101
                return nil
×
3102
        }
×
3103

3104
        features, err := getNodeFeatures(ctx, db, dbID)
×
3105
        if err != nil {
×
3106
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3107
        }
×
3108

3109
        for _, row := range rows {
×
3110
                node1, node2, err := buildNodeVertices(
×
3111
                        row.Node1Pubkey, row.Node2Pubkey,
×
3112
                )
×
3113
                if err != nil {
×
3114
                        return fmt.Errorf("unable to build node vertices: %w",
×
3115
                                err)
×
3116
                }
×
3117

3118
                edge := buildCacheableChannelInfo(
×
3119
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3120
                        node1, node2,
×
3121
                )
×
3122

×
3123
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3124
                if err != nil {
×
3125
                        return err
×
3126
                }
×
3127

3128
                p1, p2, err := buildCachedChanPolicies(
×
3129
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3130
                )
×
3131
                if err != nil {
×
3132
                        return err
×
3133
                }
×
3134

3135
                // Determine the outgoing and incoming policy for this
3136
                // channel and node combo.
3137
                outPolicy, inPolicy := p1, p2
×
3138
                if p1 != nil && node2 == nodePub {
×
3139
                        outPolicy, inPolicy = p2, p1
×
3140
                } else if p2 != nil && node1 != nodePub {
×
3141
                        outPolicy, inPolicy = p2, p1
×
3142
                }
×
3143

3144
                var cachedInPolicy *models.CachedEdgePolicy
×
3145
                if inPolicy != nil {
×
3146
                        cachedInPolicy = inPolicy
×
3147
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
3148
                        cachedInPolicy.ToNodeFeatures = features
×
3149
                }
×
3150

3151
                directedChannel := &DirectedChannel{
×
3152
                        ChannelID:    edge.ChannelID,
×
3153
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3154
                        OtherNode:    edge.NodeKey2Bytes,
×
3155
                        Capacity:     edge.Capacity,
×
3156
                        OutPolicySet: outPolicy != nil,
×
3157
                        InPolicy:     cachedInPolicy,
×
3158
                }
×
3159
                if outPolicy != nil {
×
3160
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3161
                                directedChannel.InboundFee = fee
×
3162
                        })
×
3163
                }
3164

3165
                if nodePub == edge.NodeKey2Bytes {
×
3166
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3167
                }
×
3168

3169
                if err := cb(directedChannel); err != nil {
×
3170
                        return err
×
3171
                }
×
3172
        }
3173

3174
        return nil
×
3175
}
3176

3177
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3178
// and executes the provided callback for each node. It does so via pagination
3179
// along with batch loading of the node feature bits.
3180
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3181
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
3182
                features *lnwire.FeatureVector) error) error {
×
3183

×
3184
        handleNode := func(_ context.Context,
×
3185
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
3186
                featureBits map[int64][]int) error {
×
3187

×
3188
                fv := lnwire.EmptyFeatureVector()
×
3189
                if features, exists := featureBits[dbNode.ID]; exists {
×
3190
                        for _, bit := range features {
×
3191
                                fv.Set(lnwire.FeatureBit(bit))
×
3192
                        }
×
3193
                }
3194

3195
                var pub route.Vertex
×
3196
                copy(pub[:], dbNode.PubKey)
×
3197

×
3198
                return processNode(dbNode.ID, pub, fv)
×
3199
        }
3200

3201
        queryFunc := func(ctx context.Context, lastID int64,
×
3202
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3203

×
3204
                return db.ListNodeIDsAndPubKeys(
×
3205
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
NEW
3206
                                Version: int16(lnwire.GossipVersion1),
×
3207
                                ID:      lastID,
×
3208
                                Limit:   limit,
×
3209
                        },
×
3210
                )
×
3211
        }
×
3212

3213
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3214
                return row.ID
×
3215
        }
×
3216

3217
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3218
                return node.ID, nil
×
3219
        }
×
3220

3221
        batchQueryFunc := func(ctx context.Context,
×
3222
                nodeIDs []int64) (map[int64][]int, error) {
×
3223

×
3224
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
3225
        }
×
3226

3227
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3228
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3229
                batchQueryFunc, handleNode,
×
3230
        )
×
3231
}
3232

3233
// forEachNodeChannel iterates through all channels of a node, executing
3234
// the passed callback on each. The call-back is provided with the channel's
3235
// edge information, the outgoing policy and the incoming policy for the
3236
// channel and node combo.
3237
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3238
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3239
                *models.ChannelEdgePolicy,
3240
                *models.ChannelEdgePolicy) error) error {
×
3241

×
3242
        // Get all the V1 channels for this node.
×
3243
        rows, err := db.ListChannelsByNodeID(
×
3244
                ctx, sqlc.ListChannelsByNodeIDParams{
×
NEW
3245
                        Version: int16(lnwire.GossipVersion1),
×
3246
                        NodeID1: id,
×
3247
                },
×
3248
        )
×
3249
        if err != nil {
×
3250
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3251
        }
×
3252

3253
        // Collect all the channel and policy IDs.
3254
        var (
×
3255
                chanIDs   = make([]int64, 0, len(rows))
×
3256
                policyIDs = make([]int64, 0, 2*len(rows))
×
3257
        )
×
3258
        for _, row := range rows {
×
3259
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3260

×
3261
                if row.Policy1ID.Valid {
×
3262
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3263
                }
×
3264
                if row.Policy2ID.Valid {
×
3265
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3266
                }
×
3267
        }
3268

3269
        batchData, err := batchLoadChannelData(
×
3270
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3271
        )
×
3272
        if err != nil {
×
3273
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3274
        }
×
3275

3276
        // Call the call-back for each channel and its known policies.
3277
        for _, row := range rows {
×
3278
                node1, node2, err := buildNodeVertices(
×
3279
                        row.Node1Pubkey, row.Node2Pubkey,
×
3280
                )
×
3281
                if err != nil {
×
3282
                        return fmt.Errorf("unable to build node vertices: %w",
×
3283
                                err)
×
3284
                }
×
3285

3286
                edge, err := buildEdgeInfoWithBatchData(
×
3287
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3288
                        batchData,
×
3289
                )
×
3290
                if err != nil {
×
3291
                        return fmt.Errorf("unable to build channel info: %w",
×
3292
                                err)
×
3293
                }
×
3294

3295
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3296
                if err != nil {
×
3297
                        return fmt.Errorf("unable to extract channel "+
×
3298
                                "policies: %w", err)
×
3299
                }
×
3300

3301
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3302
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3303
                )
×
3304
                if err != nil {
×
3305
                        return fmt.Errorf("unable to build channel "+
×
3306
                                "policies: %w", err)
×
3307
                }
×
3308

3309
                // Determine the outgoing and incoming policy for this
3310
                // channel and node combo.
3311
                p1ToNode := row.GraphChannel.NodeID2
×
3312
                p2ToNode := row.GraphChannel.NodeID1
×
3313
                outPolicy, inPolicy := p1, p2
×
3314
                if (p1 != nil && p1ToNode == id) ||
×
3315
                        (p2 != nil && p2ToNode != id) {
×
3316

×
3317
                        outPolicy, inPolicy = p2, p1
×
3318
                }
×
3319

3320
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3321
                        return err
×
3322
                }
×
3323
        }
3324

3325
        return nil
×
3326
}
3327

3328
// updateChanEdgePolicy upserts the channel policy info we have stored for
3329
// a channel we already know of.
3330
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3331
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
3332
        error) {
×
3333

×
3334
        var (
×
3335
                node1Pub, node2Pub route.Vertex
×
3336
                isNode1            bool
×
3337
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3338
        )
×
3339

×
3340
        // Check that this edge policy refers to a channel that we already
×
3341
        // know of. We do this explicitly so that we can return the appropriate
×
3342
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3343
        // abort the transaction which would abort the entire batch.
×
3344
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3345
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3346
                        Scid:    chanIDB,
×
NEW
3347
                        Version: int16(lnwire.GossipVersion1),
×
3348
                },
×
3349
        )
×
3350
        if errors.Is(err, sql.ErrNoRows) {
×
3351
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3352
        } else if err != nil {
×
3353
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3354
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3355
        }
×
3356

3357
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3358
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3359

×
3360
        // Figure out which node this edge is from.
×
3361
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3362
        nodeID := dbChan.NodeID1
×
3363
        if !isNode1 {
×
3364
                nodeID = dbChan.NodeID2
×
3365
        }
×
3366

3367
        var (
×
3368
                inboundBase sql.NullInt64
×
3369
                inboundRate sql.NullInt64
×
3370
        )
×
3371
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3372
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3373
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3374
        })
×
3375

3376
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
NEW
3377
                Version:     int16(lnwire.GossipVersion1),
×
3378
                ChannelID:   dbChan.ID,
×
3379
                NodeID:      nodeID,
×
3380
                Timelock:    int32(edge.TimeLockDelta),
×
3381
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3382
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3383
                MinHtlcMsat: int64(edge.MinHTLC),
×
3384
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3385
                Disabled: sql.NullBool{
×
3386
                        Valid: true,
×
3387
                        Bool:  edge.IsDisabled(),
×
3388
                },
×
3389
                MaxHtlcMsat: sql.NullInt64{
×
3390
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3391
                        Int64: int64(edge.MaxHTLC),
×
3392
                },
×
3393
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3394
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3395
                InboundBaseFeeMsat:      inboundBase,
×
3396
                InboundFeeRateMilliMsat: inboundRate,
×
3397
                Signature:               edge.SigBytes,
×
3398
        })
×
3399
        if err != nil {
×
3400
                return node1Pub, node2Pub, isNode1,
×
3401
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3402
        }
×
3403

3404
        // Convert the flat extra opaque data into a map of TLV types to
3405
        // values.
3406
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3407
        if err != nil {
×
3408
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3409
                        "marshal extra opaque data: %w", err)
×
3410
        }
×
3411

3412
        // Update the channel policy's extra signed fields.
3413
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3414
        if err != nil {
×
3415
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
3416
                        "policy extra TLVs: %w", err)
×
3417
        }
×
3418

3419
        return node1Pub, node2Pub, isNode1, nil
×
3420
}
3421

3422
// getNodeByPubKey attempts to look up a target node by its public key.
3423
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3424
        pubKey route.Vertex) (int64, *models.Node, error) {
×
3425

×
3426
        dbNode, err := db.GetNodeByPubKey(
×
3427
                ctx, sqlc.GetNodeByPubKeyParams{
×
NEW
3428
                        Version: int16(lnwire.GossipVersion1),
×
3429
                        PubKey:  pubKey[:],
×
3430
                },
×
3431
        )
×
3432
        if errors.Is(err, sql.ErrNoRows) {
×
3433
                return 0, nil, ErrGraphNodeNotFound
×
3434
        } else if err != nil {
×
3435
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3436
        }
×
3437

3438
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3439
        if err != nil {
×
3440
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3441
        }
×
3442

3443
        return dbNode.ID, node, nil
×
3444
}
3445

3446
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3447
// provided parameters.
3448
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
3449
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
3450

×
3451
        return &models.CachedEdgeInfo{
×
3452
                ChannelID:     byteOrder.Uint64(scid),
×
3453
                NodeKey1Bytes: node1Pub,
×
3454
                NodeKey2Bytes: node2Pub,
×
3455
                Capacity:      btcutil.Amount(capacity),
×
3456
        }
×
3457
}
×
3458

3459
// buildNode constructs a Node instance from the given database node
3460
// record. The node's features, addresses and extra signed fields are also
3461
// fetched from the database and set on the node.
3462
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3463
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
3464

×
3465
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
3466
        if err != nil {
×
3467
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3468
                        err)
×
3469
        }
×
3470

3471
        return buildNodeWithBatchData(dbNode, data)
×
3472
}
3473

3474
// buildNodeWithBatchData builds a models.Node instance
3475
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3476
// features/addresses/extra fields, then the corresponding fields are expected
3477
// to be present in the batchNodeData.
3478
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
3479
        batchData *batchNodeData) (*models.Node, error) {
×
3480

×
NEW
3481
        if dbNode.Version != int16(lnwire.GossipVersion1) {
×
3482
                return nil, fmt.Errorf("unsupported node version: %d",
×
3483
                        dbNode.Version)
×
3484
        }
×
3485

3486
        var pub [33]byte
×
3487
        copy(pub[:], dbNode.PubKey)
×
3488

×
NEW
3489
        node := models.NewV1ShellNode(pub)
×
3490

×
3491
        if len(dbNode.Signature) == 0 {
×
3492
                return node, nil
×
3493
        }
×
3494

3495
        node.AuthSigBytes = dbNode.Signature
×
NEW
3496

×
NEW
3497
        if dbNode.Alias.Valid {
×
NEW
3498
                node.Alias = fn.Some(dbNode.Alias.String)
×
NEW
3499
        }
×
NEW
3500
        if dbNode.LastUpdate.Valid {
×
NEW
3501
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
NEW
3502
        }
×
3503

3504
        var err error
×
3505
        if dbNode.Color.Valid {
×
NEW
3506
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3507
                if err != nil {
×
3508
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3509
                                err)
×
3510
                }
×
3511

NEW
3512
                node.Color = fn.Some(nodeColor)
×
3513
        }
3514

3515
        // Use preloaded features.
3516
        if features, exists := batchData.features[dbNode.ID]; exists {
×
3517
                fv := lnwire.EmptyFeatureVector()
×
3518
                for _, bit := range features {
×
3519
                        fv.Set(lnwire.FeatureBit(bit))
×
3520
                }
×
3521
                node.Features = fv
×
3522
        }
3523

3524
        // Use preloaded addresses.
3525
        addresses, exists := batchData.addresses[dbNode.ID]
×
3526
        if exists && len(addresses) > 0 {
×
3527
                node.Addresses, err = buildNodeAddresses(addresses)
×
3528
                if err != nil {
×
3529
                        return nil, fmt.Errorf("unable to build addresses "+
×
3530
                                "for node(%d): %w", dbNode.ID, err)
×
3531
                }
×
3532
        }
3533

3534
        // Use preloaded extra fields.
3535
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
3536
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
3537
                if err != nil {
×
3538
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3539
                                "signed fields: %w", err)
×
3540
                }
×
3541
                if len(recs) != 0 {
×
3542
                        node.ExtraOpaqueData = recs
×
3543
                }
×
3544
        }
3545

3546
        return node, nil
×
3547
}
3548

3549
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3550
// with the preloaded data, and executes the provided callback for each node.
3551
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3552
        db SQLQueries, nodes []sqlc.GraphNode,
3553
        cb func(dbID int64, node *models.Node) error) error {
×
3554

×
3555
        // Extract node IDs for batch loading.
×
3556
        nodeIDs := make([]int64, len(nodes))
×
3557
        for i, node := range nodes {
×
3558
                nodeIDs[i] = node.ID
×
3559
        }
×
3560

3561
        // Batch load all related data for this page.
3562
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
3563
        if err != nil {
×
3564
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3565
        }
×
3566

3567
        for _, dbNode := range nodes {
×
3568
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
3569
                if err != nil {
×
3570
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3571
                                dbNode.ID, err)
×
3572
                }
×
3573

3574
                if err := cb(dbNode.ID, node); err != nil {
×
3575
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
3576
                                dbNode.ID, err)
×
3577
                }
×
3578
        }
3579

3580
        return nil
×
3581
}
3582

3583
// getNodeFeatures fetches the feature bits and constructs the feature vector
3584
// for a node with the given DB ID.
3585
func getNodeFeatures(ctx context.Context, db SQLQueries,
3586
        nodeID int64) (*lnwire.FeatureVector, error) {
×
3587

×
3588
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3589
        if err != nil {
×
3590
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3591
                        nodeID, err)
×
3592
        }
×
3593

3594
        features := lnwire.EmptyFeatureVector()
×
3595
        for _, feature := range rows {
×
3596
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3597
        }
×
3598

3599
        return features, nil
×
3600
}
3601

3602
// upsertNode upserts the node record into the database. If the node already
3603
// exists, then the node's information is updated. If the node doesn't exist,
3604
// then a new node is created. The node's features, addresses and extra TLV
3605
// types are also updated. The node's DB ID is returned.
3606
func upsertNode(ctx context.Context, db SQLQueries,
3607
        node *models.Node) (int64, error) {
×
3608

×
3609
        params := sqlc.UpsertNodeParams{
×
NEW
3610
                Version: int16(lnwire.GossipVersion1),
×
3611
                PubKey:  node.PubKeyBytes[:],
×
3612
        }
×
3613

×
NEW
3614
        if node.HaveAnnouncement() {
×
NEW
3615
                switch node.Version {
×
NEW
3616
                case lnwire.GossipVersion1:
×
NEW
3617
                        params.LastUpdate = sqldb.SQLInt64(
×
NEW
3618
                                node.LastUpdate.Unix(),
×
NEW
3619
                        )
×
3620

NEW
3621
                case lnwire.GossipVersion2:
×
3622

NEW
3623
                default:
×
NEW
3624
                        return 0, fmt.Errorf("unknown gossip version: %d",
×
NEW
3625
                                node.Version)
×
3626
                }
3627

NEW
3628
                node.Color.WhenSome(func(rgba color.RGBA) {
×
NEW
3629
                        params.Color = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
NEW
3630
                })
×
NEW
3631
                node.Alias.WhenSome(func(s string) {
×
NEW
3632
                        params.Alias = sqldb.SQLStrValid(s)
×
NEW
3633
                })
×
3634

3635
                params.Signature = node.AuthSigBytes
×
3636
        }
3637

3638
        nodeID, err := db.UpsertNode(ctx, params)
×
3639
        if err != nil {
×
3640
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
3641
                        err)
×
3642
        }
×
3643

3644
        // We can exit here if we don't have the announcement yet.
NEW
3645
        if !node.HaveAnnouncement() {
×
3646
                return nodeID, nil
×
3647
        }
×
3648

3649
        // Update the node's features.
3650
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3651
        if err != nil {
×
3652
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3653
        }
×
3654

3655
        // Update the node's addresses.
3656
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3657
        if err != nil {
×
3658
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3659
        }
×
3660

3661
        // Convert the flat extra opaque data into a map of TLV types to
3662
        // values.
3663
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
3664
        if err != nil {
×
3665
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
3666
                        err)
×
3667
        }
×
3668

3669
        // Update the node's extra signed fields.
3670
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3671
        if err != nil {
×
3672
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3673
        }
×
3674

3675
        return nodeID, nil
×
3676
}
3677

3678
// upsertNodeFeatures updates the node's features node_features table. This
3679
// includes deleting any feature bits no longer present and inserting any new
3680
// feature bits. If the feature bit does not yet exist in the features table,
3681
// then an entry is created in that table first.
3682
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
3683
        features *lnwire.FeatureVector) error {
×
3684

×
3685
        // Get any existing features for the node.
×
3686
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
3687
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3688
                return err
×
3689
        }
×
3690

3691
        // Copy the nodes latest set of feature bits.
3692
        newFeatures := make(map[int32]struct{})
×
3693
        if features != nil {
×
3694
                for feature := range features.Features() {
×
3695
                        newFeatures[int32(feature)] = struct{}{}
×
3696
                }
×
3697
        }
3698

3699
        // For any current feature that already exists in the DB, remove it from
3700
        // the in-memory map. For any existing feature that does not exist in
3701
        // the in-memory map, delete it from the database.
3702
        for _, feature := range existingFeatures {
×
3703
                // The feature is still present, so there are no updates to be
×
3704
                // made.
×
3705
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3706
                        delete(newFeatures, feature.FeatureBit)
×
3707
                        continue
×
3708
                }
3709

3710
                // The feature is no longer present, so we remove it from the
3711
                // database.
3712
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3713
                        NodeID:     nodeID,
×
3714
                        FeatureBit: feature.FeatureBit,
×
3715
                })
×
3716
                if err != nil {
×
3717
                        return fmt.Errorf("unable to delete node(%d) "+
×
3718
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3719
                                err)
×
3720
                }
×
3721
        }
3722

3723
        // Any remaining entries in newFeatures are new features that need to be
3724
        // added to the database for the first time.
3725
        for feature := range newFeatures {
×
3726
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3727
                        NodeID:     nodeID,
×
3728
                        FeatureBit: feature,
×
3729
                })
×
3730
                if err != nil {
×
3731
                        return fmt.Errorf("unable to insert node(%d) "+
×
3732
                                "feature(%v): %w", nodeID, feature, err)
×
3733
                }
×
3734
        }
3735

3736
        return nil
×
3737
}
3738

3739
// fetchNodeFeatures fetches the features for a node with the given public key.
3740
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3741
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3742

×
3743
        rows, err := queries.GetNodeFeaturesByPubKey(
×
3744
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3745
                        PubKey:  nodePub[:],
×
NEW
3746
                        Version: int16(lnwire.GossipVersion1),
×
3747
                },
×
3748
        )
×
3749
        if err != nil {
×
3750
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
3751
                        nodePub, err)
×
3752
        }
×
3753

3754
        features := lnwire.EmptyFeatureVector()
×
3755
        for _, bit := range rows {
×
3756
                features.Set(lnwire.FeatureBit(bit))
×
3757
        }
×
3758

3759
        return features, nil
×
3760
}
3761

3762
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialized and deserialized.
//
// NOTE: these values are written to the database as the address Type column
// and read back when parsing stored addresses, so existing values must never
// be changed or reused.
type dbAddressType uint8

const (
	// addressTypeIPv4 is used for TCP addresses with an IPv4 host.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is used for TCP addresses with an IPv6 host.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is used for Tor v2 onion service addresses.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is used for Tor v3 onion service addresses.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS is used for lnwire.DNSAddress addresses.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque is used for lnwire.OpaqueAddrs address data.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3775

3776
// collectAddressRecords collects the addresses from the provided
3777
// net.Addr slice and returns a map of dbAddressType to a slice of address
3778
// strings.
3779
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
3780
        error) {
×
3781

×
3782
        // Copy the nodes latest set of addresses.
×
3783
        newAddresses := map[dbAddressType][]string{
×
3784
                addressTypeIPv4:   {},
×
3785
                addressTypeIPv6:   {},
×
3786
                addressTypeTorV2:  {},
×
3787
                addressTypeTorV3:  {},
×
3788
                addressTypeDNS:    {},
×
3789
                addressTypeOpaque: {},
×
3790
        }
×
3791
        addAddr := func(t dbAddressType, addr net.Addr) {
×
3792
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3793
        }
×
3794

3795
        for _, address := range addresses {
×
3796
                switch addr := address.(type) {
×
3797
                case *net.TCPAddr:
×
3798
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3799
                                addAddr(addressTypeIPv4, addr)
×
3800
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3801
                                addAddr(addressTypeIPv6, addr)
×
3802
                        } else {
×
3803
                                return nil, fmt.Errorf("unhandled IP "+
×
3804
                                        "address: %v", addr)
×
3805
                        }
×
3806

3807
                case *tor.OnionAddr:
×
3808
                        switch len(addr.OnionService) {
×
3809
                        case tor.V2Len:
×
3810
                                addAddr(addressTypeTorV2, addr)
×
3811
                        case tor.V3Len:
×
3812
                                addAddr(addressTypeTorV3, addr)
×
3813
                        default:
×
3814
                                return nil, fmt.Errorf("invalid length for " +
×
3815
                                        "a tor address")
×
3816
                        }
3817

3818
                case *lnwire.DNSAddress:
×
3819
                        addAddr(addressTypeDNS, addr)
×
3820

3821
                case *lnwire.OpaqueAddrs:
×
3822
                        addAddr(addressTypeOpaque, addr)
×
3823

3824
                default:
×
3825
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3826
                                addr)
×
3827
                }
3828
        }
3829

3830
        return newAddresses, nil
×
3831
}
3832

3833
// upsertNodeAddresses updates the node's addresses in the database. This
3834
// includes deleting any existing addresses and inserting the new set of
3835
// addresses. The deletion is necessary since the ordering of the addresses may
3836
// change, and we need to ensure that the database reflects the latest set of
3837
// addresses so that at the time of reconstructing the node announcement, the
3838
// order is preserved and the signature over the message remains valid.
3839
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3840
        addresses []net.Addr) error {
×
3841

×
3842
        // Delete any existing addresses for the node. This is required since
×
3843
        // even if the new set of addresses is the same, the ordering may have
×
3844
        // changed for a given address type.
×
3845
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3846
        if err != nil {
×
3847
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3848
                        nodeID, err)
×
3849
        }
×
3850

3851
        newAddresses, err := collectAddressRecords(addresses)
×
3852
        if err != nil {
×
3853
                return err
×
3854
        }
×
3855

3856
        // Any remaining entries in newAddresses are new addresses that need to
3857
        // be added to the database for the first time.
3858
        for addrType, addrList := range newAddresses {
×
3859
                for position, addr := range addrList {
×
3860
                        err := db.UpsertNodeAddress(
×
3861
                                ctx, sqlc.UpsertNodeAddressParams{
×
3862
                                        NodeID:   nodeID,
×
3863
                                        Type:     int16(addrType),
×
3864
                                        Address:  addr,
×
3865
                                        Position: int32(position),
×
3866
                                },
×
3867
                        )
×
3868
                        if err != nil {
×
3869
                                return fmt.Errorf("unable to insert "+
×
3870
                                        "node(%d) address(%v): %w", nodeID,
×
3871
                                        addr, err)
×
3872
                        }
×
3873
                }
3874
        }
3875

3876
        return nil
×
3877
}
3878

3879
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3880
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3881
        error) {
×
3882

×
3883
        // GetNodeAddresses ensures that the addresses for a given type are
×
3884
        // returned in the same order as they were inserted.
×
3885
        rows, err := db.GetNodeAddresses(ctx, id)
×
3886
        if err != nil {
×
3887
                return nil, err
×
3888
        }
×
3889

3890
        addresses := make([]net.Addr, 0, len(rows))
×
3891
        for _, row := range rows {
×
3892
                address := row.Address
×
3893

×
3894
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3895
                if err != nil {
×
3896
                        return nil, fmt.Errorf("unable to parse address "+
×
3897
                                "for node(%d): %v: %w", id, address, err)
×
3898
                }
×
3899

3900
                addresses = append(addresses, addr)
×
3901
        }
3902

3903
        // If we have no addresses, then we'll return nil instead of an
3904
        // empty slice.
3905
        if len(addresses) == 0 {
×
3906
                addresses = nil
×
3907
        }
×
3908

3909
        return addresses, nil
×
3910
}
3911

3912
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3913
// database. This includes updating any existing types, inserting any new types,
3914
// and deleting any types that are no longer present.
3915
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
3916
        nodeID int64, extraFields map[uint64][]byte) error {
×
3917

×
3918
        // Get any existing extra signed fields for the node.
×
3919
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3920
        if err != nil {
×
3921
                return err
×
3922
        }
×
3923

3924
        // Make a lookup map of the existing field types so that we can use it
3925
        // to keep track of any fields we should delete.
3926
        m := make(map[uint64]bool)
×
3927
        for _, field := range existingFields {
×
3928
                m[uint64(field.Type)] = true
×
3929
        }
×
3930

3931
        // For all the new fields, we'll upsert them and remove them from the
3932
        // map of existing fields.
3933
        for tlvType, value := range extraFields {
×
3934
                err = db.UpsertNodeExtraType(
×
3935
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
3936
                                NodeID: nodeID,
×
3937
                                Type:   int64(tlvType),
×
3938
                                Value:  value,
×
3939
                        },
×
3940
                )
×
3941
                if err != nil {
×
3942
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3943
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3944
                }
×
3945

3946
                // Remove the field from the map of existing fields if it was
3947
                // present.
3948
                delete(m, tlvType)
×
3949
        }
3950

3951
        // For all the fields that are left in the map of existing fields, we'll
3952
        // delete them as they are no longer present in the new set of fields.
3953
        for tlvType := range m {
×
3954
                err = db.DeleteExtraNodeType(
×
3955
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
3956
                                NodeID: nodeID,
×
3957
                                Type:   int64(tlvType),
×
3958
                        },
×
3959
                )
×
3960
                if err != nil {
×
3961
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
3962
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3963
                }
×
3964
        }
3965

3966
        return nil
×
3967
}
3968

3969
// srcNodeInfo holds the information about the source node of the graph for a
// single gossip protocol version. Instances are cached in SQLStore.srcNodes,
// keyed by lnwire.GossipVersion, by getSourceNode.
type srcNodeInfo struct {
	// id is the DB level ID of the source node entry in the nodes table
	// (taken from the NodeID column of the source node query).
	id int64

	// pub is the public key of the source node.
	pub route.Vertex
}
3977

3978
// sourceNode returns the DB node ID and pub key of the source node for the
3979
// specified protocol version.
3980
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
NEW
3981
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
3982

×
3983
        s.srcNodeMu.Lock()
×
3984
        defer s.srcNodeMu.Unlock()
×
3985

×
3986
        // If we already have the source node ID and pub key cached, then
×
3987
        // return them.
×
3988
        if info, ok := s.srcNodes[version]; ok {
×
3989
                return info.id, info.pub, nil
×
3990
        }
×
3991

3992
        var pubKey route.Vertex
×
3993

×
3994
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3995
        if err != nil {
×
3996
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3997
                        err)
×
3998
        }
×
3999

4000
        if len(nodes) == 0 {
×
4001
                return 0, pubKey, ErrSourceNodeNotSet
×
4002
        } else if len(nodes) > 1 {
×
4003
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
4004
                        "protocol %s found", version)
×
4005
        }
×
4006

4007
        copy(pubKey[:], nodes[0].PubKey)
×
4008

×
4009
        s.srcNodes[version] = &srcNodeInfo{
×
4010
                id:  nodes[0].NodeID,
×
4011
                pub: pubKey,
×
4012
        }
×
4013

×
4014
        return nodes[0].NodeID, pubKey, nil
×
4015
}
4016

4017
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
4018
// This then produces a map from TLV type to value. If the input is not a
4019
// valid TLV stream, then an error is returned.
4020
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4021
        r := bytes.NewReader(data)
×
4022

×
4023
        tlvStream, err := tlv.NewStream()
×
4024
        if err != nil {
×
4025
                return nil, err
×
4026
        }
×
4027

4028
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4029
        // pass it into the P2P decoding variant.
4030
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
4031
        if err != nil {
×
4032
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
4033
        }
×
4034
        if len(parsedTypes) == 0 {
×
4035
                return nil, nil
×
4036
        }
×
4037

4038
        records := make(map[uint64][]byte)
×
4039
        for k, v := range parsedTypes {
×
4040
                records[uint64(k)] = v
×
4041
        }
×
4042

4043
        return records, nil
×
4044
}
4045

4046
// insertChannel inserts a new channel record into the database.
4047
func insertChannel(ctx context.Context, db SQLQueries,
4048
        edge *models.ChannelEdgeInfo) error {
×
4049

×
4050
        // Make sure that at least a "shell" entry for each node is present in
×
4051
        // the nodes table.
×
4052
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4053
        if err != nil {
×
4054
                return fmt.Errorf("unable to create shell node: %w", err)
×
4055
        }
×
4056

4057
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4058
        if err != nil {
×
4059
                return fmt.Errorf("unable to create shell node: %w", err)
×
4060
        }
×
4061

4062
        var capacity sql.NullInt64
×
4063
        if edge.Capacity != 0 {
×
4064
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4065
        }
×
4066

4067
        createParams := sqlc.CreateChannelParams{
×
NEW
4068
                Version:     int16(lnwire.GossipVersion1),
×
4069
                Scid:        channelIDToBytes(edge.ChannelID),
×
4070
                NodeID1:     node1DBID,
×
4071
                NodeID2:     node2DBID,
×
4072
                Outpoint:    edge.ChannelPoint.String(),
×
4073
                Capacity:    capacity,
×
4074
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4075
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4076
        }
×
4077

×
4078
        if edge.AuthProof != nil {
×
4079
                proof := edge.AuthProof
×
4080

×
4081
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4082
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4083
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4084
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4085
        }
×
4086

4087
        // Insert the new channel record.
4088
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4089
        if err != nil {
×
4090
                return err
×
4091
        }
×
4092

4093
        // Insert any channel features.
4094
        for feature := range edge.Features.Features() {
×
4095
                err = db.InsertChannelFeature(
×
4096
                        ctx, sqlc.InsertChannelFeatureParams{
×
4097
                                ChannelID:  dbChanID,
×
4098
                                FeatureBit: int32(feature),
×
4099
                        },
×
4100
                )
×
4101
                if err != nil {
×
4102
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4103
                                "feature(%v): %w", dbChanID, feature, err)
×
4104
                }
×
4105
        }
4106

4107
        // Finally, insert any extra TLV fields in the channel announcement.
4108
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4109
        if err != nil {
×
4110
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4111
                        err)
×
4112
        }
×
4113

4114
        for tlvType, value := range extra {
×
4115
                err := db.UpsertChannelExtraType(
×
4116
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
4117
                                ChannelID: dbChanID,
×
4118
                                Type:      int64(tlvType),
×
4119
                                Value:     value,
×
4120
                        },
×
4121
                )
×
4122
                if err != nil {
×
4123
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4124
                                "extra signed field(%v): %w", edge.ChannelID,
×
4125
                                tlvType, err)
×
4126
                }
×
4127
        }
4128

4129
        return nil
×
4130
}
4131

4132
// maybeCreateShellNode checks if a shell node entry exists for the
4133
// given public key. If it does not exist, then a new shell node entry is
4134
// created. The ID of the node is returned. A shell node only has a protocol
4135
// version and public key persisted.
4136
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
4137
        pubKey route.Vertex) (int64, error) {
×
4138

×
4139
        dbNode, err := db.GetNodeByPubKey(
×
4140
                ctx, sqlc.GetNodeByPubKeyParams{
×
4141
                        PubKey:  pubKey[:],
×
NEW
4142
                        Version: int16(lnwire.GossipVersion1),
×
4143
                },
×
4144
        )
×
4145
        // The node exists. Return the ID.
×
4146
        if err == nil {
×
4147
                return dbNode.ID, nil
×
4148
        } else if !errors.Is(err, sql.ErrNoRows) {
×
4149
                return 0, err
×
4150
        }
×
4151

4152
        // Otherwise, the node does not exist, so we create a shell entry for
4153
        // it.
4154
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
NEW
4155
                Version: int16(lnwire.GossipVersion1),
×
4156
                PubKey:  pubKey[:],
×
4157
        })
×
4158
        if err != nil {
×
4159
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4160
        }
×
4161

4162
        return id, nil
×
4163
}
4164

4165
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4166
// the database. This includes deleting any existing types and then inserting
4167
// the new types.
4168
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4169
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4170

×
4171
        // Delete all existing extra signed fields for the channel policy.
×
4172
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4173
        if err != nil {
×
4174
                return fmt.Errorf("unable to delete "+
×
4175
                        "existing policy extra signed fields for policy %d: %w",
×
4176
                        chanPolicyID, err)
×
4177
        }
×
4178

4179
        // Insert all new extra signed fields for the channel policy.
4180
        for tlvType, value := range extraFields {
×
4181
                err = db.UpsertChanPolicyExtraType(
×
4182
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
4183
                                ChannelPolicyID: chanPolicyID,
×
4184
                                Type:            int64(tlvType),
×
4185
                                Value:           value,
×
4186
                        },
×
4187
                )
×
4188
                if err != nil {
×
4189
                        return fmt.Errorf("unable to insert "+
×
4190
                                "channel_policy(%d) extra signed field(%v): %w",
×
4191
                                chanPolicyID, tlvType, err)
×
4192
                }
×
4193
        }
4194

4195
        return nil
×
4196
}
4197

4198
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4199
// provided dbChanRow and also fetches any other required information
4200
// to construct the edge info.
4201
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4202
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
4203
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
4204

×
4205
        data, err := batchLoadChannelData(
×
4206
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4207
        )
×
4208
        if err != nil {
×
4209
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4210
                        err)
×
4211
        }
×
4212

4213
        return buildEdgeInfoWithBatchData(
×
4214
                cfg.ChainHash, dbChan, node1, node2, data,
×
4215
        )
×
4216
}
4217

4218
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4219
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4220
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
4221
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4222

×
NEW
4223
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4224
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4225
                        dbChan.Version)
×
4226
        }
×
4227

4228
        // Use pre-loaded features and extras types.
4229
        fv := lnwire.EmptyFeatureVector()
×
4230
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4231
                for _, bit := range features {
×
4232
                        fv.Set(lnwire.FeatureBit(bit))
×
4233
                }
×
4234
        }
4235

4236
        var extras map[uint64][]byte
×
4237
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4238
        if exists {
×
4239
                extras = channelExtras
×
4240
        } else {
×
4241
                extras = make(map[uint64][]byte)
×
4242
        }
×
4243

4244
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4245
        if err != nil {
×
4246
                return nil, err
×
4247
        }
×
4248

4249
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4250
        if err != nil {
×
4251
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4252
                        "fields: %w", err)
×
4253
        }
×
4254
        if recs == nil {
×
4255
                recs = make([]byte, 0)
×
4256
        }
×
4257

4258
        var btcKey1, btcKey2 route.Vertex
×
4259
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4260
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4261

×
4262
        channel := &models.ChannelEdgeInfo{
×
4263
                ChainHash:        chain,
×
4264
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4265
                NodeKey1Bytes:    node1,
×
4266
                NodeKey2Bytes:    node2,
×
4267
                BitcoinKey1Bytes: btcKey1,
×
4268
                BitcoinKey2Bytes: btcKey2,
×
4269
                ChannelPoint:     *op,
×
4270
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4271
                Features:         fv,
×
4272
                ExtraOpaqueData:  recs,
×
4273
        }
×
4274

×
4275
        // We always set all the signatures at the same time, so we can
×
4276
        // safely check if one signature is present to determine if we have the
×
4277
        // rest of the signatures for the auth proof.
×
4278
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4279
                channel.AuthProof = &models.ChannelAuthProof{
×
4280
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4281
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4282
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4283
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4284
                }
×
4285
        }
×
4286

4287
        return channel, nil
×
4288
}
4289

4290
// buildNodeVertices is a helper that converts raw node public keys
4291
// into route.Vertex instances.
4292
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4293
        route.Vertex, error) {
×
4294

×
4295
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
4296
        if err != nil {
×
4297
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4298
                        "create vertex from node1 pubkey: %w", err)
×
4299
        }
×
4300

4301
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
4302
        if err != nil {
×
4303
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4304
                        "create vertex from node2 pubkey: %w", err)
×
4305
        }
×
4306

4307
        return node1Vertex, node2Vertex, nil
×
4308
}
4309

4310
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4311
// retrieves all the extra info required to build the complete
4312
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4313
// the provided sqlc.GraphChannelPolicy records are nil.
4314
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4315
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4316
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4317
        *models.ChannelEdgePolicy, error) {
×
4318

×
4319
        if dbPol1 == nil && dbPol2 == nil {
×
4320
                return nil, nil, nil
×
4321
        }
×
4322

4323
        var policyIDs = make([]int64, 0, 2)
×
4324
        if dbPol1 != nil {
×
4325
                policyIDs = append(policyIDs, dbPol1.ID)
×
4326
        }
×
4327
        if dbPol2 != nil {
×
4328
                policyIDs = append(policyIDs, dbPol2.ID)
×
4329
        }
×
4330

4331
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4332
        if err != nil {
×
4333
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4334
                        "data: %w", err)
×
4335
        }
×
4336

4337
        pol1, err := buildChanPolicyWithBatchData(
×
4338
                dbPol1, channelID, node2, batchData,
×
4339
        )
×
4340
        if err != nil {
×
4341
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
4342
        }
×
4343

4344
        pol2, err := buildChanPolicyWithBatchData(
×
4345
                dbPol2, channelID, node1, batchData,
×
4346
        )
×
4347
        if err != nil {
×
4348
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4349
        }
×
4350

4351
        return pol1, pol2, nil
×
4352
}
4353

4354
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4355
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4356
// then nil is returned for it.
4357
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4358
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4359
        *models.CachedEdgePolicy, error) {
×
4360

×
4361
        var p1, p2 *models.CachedEdgePolicy
×
4362
        if dbPol1 != nil {
×
4363
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
4364
                if err != nil {
×
4365
                        return nil, nil, err
×
4366
                }
×
4367

4368
                p1 = models.NewCachedPolicy(policy1)
×
4369
        }
4370
        if dbPol2 != nil {
×
4371
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4372
                if err != nil {
×
4373
                        return nil, nil, err
×
4374
                }
×
4375

4376
                p2 = models.NewCachedPolicy(policy2)
×
4377
        }
4378

4379
        return p1, p2, nil
×
4380
}
4381

4382
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4383
// provided sqlc.GraphChannelPolicy and other required information.
4384
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4385
        extras map[uint64][]byte,
4386
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4387

×
4388
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4389
        if err != nil {
×
4390
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4391
                        "fields: %w", err)
×
4392
        }
×
4393

4394
        var inboundFee fn.Option[lnwire.Fee]
×
4395
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4396
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4397

×
4398
                inboundFee = fn.Some(lnwire.Fee{
×
4399
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4400
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4401
                })
×
4402
        }
×
4403

4404
        return &models.ChannelEdgePolicy{
×
4405
                SigBytes:  dbPolicy.Signature,
×
4406
                ChannelID: channelID,
×
4407
                LastUpdate: time.Unix(
×
4408
                        dbPolicy.LastUpdate.Int64, 0,
×
4409
                ),
×
4410
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4411
                        dbPolicy.MessageFlags,
×
4412
                ),
×
4413
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4414
                        dbPolicy.ChannelFlags,
×
4415
                ),
×
4416
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
4417
                MinHTLC: lnwire.MilliSatoshi(
×
4418
                        dbPolicy.MinHtlcMsat,
×
4419
                ),
×
4420
                MaxHTLC: lnwire.MilliSatoshi(
×
4421
                        dbPolicy.MaxHtlcMsat.Int64,
×
4422
                ),
×
4423
                FeeBaseMSat: lnwire.MilliSatoshi(
×
4424
                        dbPolicy.BaseFeeMsat,
×
4425
                ),
×
4426
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4427
                ToNode:                    toNode,
×
4428
                InboundFee:                inboundFee,
×
4429
                ExtraOpaqueData:           recs,
×
4430
        }, nil
×
4431
}
4432

4433
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give
4434
// row which is expected to be a sqlc type that contains channel policy
4435
// information. It returns two policies, which may be nil if the policy
4436
// information is not present in the row.
4437
//
4438
//nolint:ll,dupl,funlen
4439
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4440
        *sqlc.GraphChannelPolicy, error) {
×
4441

×
4442
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4443
        switch r := row.(type) {
×
4444
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4445
                if r.Policy1Timelock.Valid {
×
4446
                        policy1 = &sqlc.GraphChannelPolicy{
×
4447
                                Timelock:                r.Policy1Timelock.Int32,
×
4448
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4449
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4450
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4451
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4452
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4453
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4454
                                Disabled:                r.Policy1Disabled,
×
4455
                                MessageFlags:            r.Policy1MessageFlags,
×
4456
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4457
                        }
×
4458
                }
×
4459
                if r.Policy2Timelock.Valid {
×
4460
                        policy2 = &sqlc.GraphChannelPolicy{
×
4461
                                Timelock:                r.Policy2Timelock.Int32,
×
4462
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4463
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4464
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4465
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4466
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4467
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4468
                                Disabled:                r.Policy2Disabled,
×
4469
                                MessageFlags:            r.Policy2MessageFlags,
×
4470
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4471
                        }
×
4472
                }
×
4473

4474
                return policy1, policy2, nil
×
4475

4476
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4477
                if r.Policy1ID.Valid {
×
4478
                        policy1 = &sqlc.GraphChannelPolicy{
×
4479
                                ID:                      r.Policy1ID.Int64,
×
4480
                                Version:                 r.Policy1Version.Int16,
×
4481
                                ChannelID:               r.GraphChannel.ID,
×
4482
                                NodeID:                  r.Policy1NodeID.Int64,
×
4483
                                Timelock:                r.Policy1Timelock.Int32,
×
4484
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4485
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4486
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4487
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4488
                                LastUpdate:              r.Policy1LastUpdate,
×
4489
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4490
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4491
                                Disabled:                r.Policy1Disabled,
×
4492
                                MessageFlags:            r.Policy1MessageFlags,
×
4493
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4494
                                Signature:               r.Policy1Signature,
×
4495
                        }
×
4496
                }
×
4497
                if r.Policy2ID.Valid {
×
4498
                        policy2 = &sqlc.GraphChannelPolicy{
×
4499
                                ID:                      r.Policy2ID.Int64,
×
4500
                                Version:                 r.Policy2Version.Int16,
×
4501
                                ChannelID:               r.GraphChannel.ID,
×
4502
                                NodeID:                  r.Policy2NodeID.Int64,
×
4503
                                Timelock:                r.Policy2Timelock.Int32,
×
4504
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4505
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4506
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4507
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4508
                                LastUpdate:              r.Policy2LastUpdate,
×
4509
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4510
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4511
                                Disabled:                r.Policy2Disabled,
×
4512
                                MessageFlags:            r.Policy2MessageFlags,
×
4513
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4514
                                Signature:               r.Policy2Signature,
×
4515
                        }
×
4516
                }
×
4517

4518
                return policy1, policy2, nil
×
4519

4520
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4521
                if r.Policy1ID.Valid {
×
4522
                        policy1 = &sqlc.GraphChannelPolicy{
×
4523
                                ID:                      r.Policy1ID.Int64,
×
4524
                                Version:                 r.Policy1Version.Int16,
×
4525
                                ChannelID:               r.GraphChannel.ID,
×
4526
                                NodeID:                  r.Policy1NodeID.Int64,
×
4527
                                Timelock:                r.Policy1Timelock.Int32,
×
4528
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4529
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4530
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4531
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4532
                                LastUpdate:              r.Policy1LastUpdate,
×
4533
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4534
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4535
                                Disabled:                r.Policy1Disabled,
×
4536
                                MessageFlags:            r.Policy1MessageFlags,
×
4537
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4538
                                Signature:               r.Policy1Signature,
×
4539
                        }
×
4540
                }
×
4541
                if r.Policy2ID.Valid {
×
4542
                        policy2 = &sqlc.GraphChannelPolicy{
×
4543
                                ID:                      r.Policy2ID.Int64,
×
4544
                                Version:                 r.Policy2Version.Int16,
×
4545
                                ChannelID:               r.GraphChannel.ID,
×
4546
                                NodeID:                  r.Policy2NodeID.Int64,
×
4547
                                Timelock:                r.Policy2Timelock.Int32,
×
4548
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4549
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4550
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4551
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4552
                                LastUpdate:              r.Policy2LastUpdate,
×
4553
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4554
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4555
                                Disabled:                r.Policy2Disabled,
×
4556
                                MessageFlags:            r.Policy2MessageFlags,
×
4557
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4558
                                Signature:               r.Policy2Signature,
×
4559
                        }
×
4560
                }
×
4561

4562
                return policy1, policy2, nil
×
4563

4564
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4565
                if r.Policy1ID.Valid {
×
4566
                        policy1 = &sqlc.GraphChannelPolicy{
×
4567
                                ID:                      r.Policy1ID.Int64,
×
4568
                                Version:                 r.Policy1Version.Int16,
×
4569
                                ChannelID:               r.GraphChannel.ID,
×
4570
                                NodeID:                  r.Policy1NodeID.Int64,
×
4571
                                Timelock:                r.Policy1Timelock.Int32,
×
4572
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4573
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4574
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4575
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4576
                                LastUpdate:              r.Policy1LastUpdate,
×
4577
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4578
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4579
                                Disabled:                r.Policy1Disabled,
×
4580
                                MessageFlags:            r.Policy1MessageFlags,
×
4581
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4582
                                Signature:               r.Policy1Signature,
×
4583
                        }
×
4584
                }
×
4585
                if r.Policy2ID.Valid {
×
4586
                        policy2 = &sqlc.GraphChannelPolicy{
×
4587
                                ID:                      r.Policy2ID.Int64,
×
4588
                                Version:                 r.Policy2Version.Int16,
×
4589
                                ChannelID:               r.GraphChannel.ID,
×
4590
                                NodeID:                  r.Policy2NodeID.Int64,
×
4591
                                Timelock:                r.Policy2Timelock.Int32,
×
4592
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4593
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4594
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4595
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4596
                                LastUpdate:              r.Policy2LastUpdate,
×
4597
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4598
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4599
                                Disabled:                r.Policy2Disabled,
×
4600
                                MessageFlags:            r.Policy2MessageFlags,
×
4601
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4602
                                Signature:               r.Policy2Signature,
×
4603
                        }
×
4604
                }
×
4605

4606
                return policy1, policy2, nil
×
4607

4608
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4609
                if r.Policy1ID.Valid {
×
4610
                        policy1 = &sqlc.GraphChannelPolicy{
×
4611
                                ID:                      r.Policy1ID.Int64,
×
4612
                                Version:                 r.Policy1Version.Int16,
×
4613
                                ChannelID:               r.GraphChannel.ID,
×
4614
                                NodeID:                  r.Policy1NodeID.Int64,
×
4615
                                Timelock:                r.Policy1Timelock.Int32,
×
4616
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4617
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4618
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4619
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4620
                                LastUpdate:              r.Policy1LastUpdate,
×
4621
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4622
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4623
                                Disabled:                r.Policy1Disabled,
×
4624
                                MessageFlags:            r.Policy1MessageFlags,
×
4625
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4626
                                Signature:               r.Policy1Signature,
×
4627
                        }
×
4628
                }
×
4629
                if r.Policy2ID.Valid {
×
4630
                        policy2 = &sqlc.GraphChannelPolicy{
×
4631
                                ID:                      r.Policy2ID.Int64,
×
4632
                                Version:                 r.Policy2Version.Int16,
×
4633
                                ChannelID:               r.GraphChannel.ID,
×
4634
                                NodeID:                  r.Policy2NodeID.Int64,
×
4635
                                Timelock:                r.Policy2Timelock.Int32,
×
4636
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4637
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4638
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4639
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4640
                                LastUpdate:              r.Policy2LastUpdate,
×
4641
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4642
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4643
                                Disabled:                r.Policy2Disabled,
×
4644
                                MessageFlags:            r.Policy2MessageFlags,
×
4645
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4646
                                Signature:               r.Policy2Signature,
×
4647
                        }
×
4648
                }
×
4649

4650
                return policy1, policy2, nil
×
4651

4652
        case sqlc.ListChannelsForNodeIDsRow:
×
4653
                if r.Policy1ID.Valid {
×
4654
                        policy1 = &sqlc.GraphChannelPolicy{
×
4655
                                ID:                      r.Policy1ID.Int64,
×
4656
                                Version:                 r.Policy1Version.Int16,
×
4657
                                ChannelID:               r.GraphChannel.ID,
×
4658
                                NodeID:                  r.Policy1NodeID.Int64,
×
4659
                                Timelock:                r.Policy1Timelock.Int32,
×
4660
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4661
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4662
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4663
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4664
                                LastUpdate:              r.Policy1LastUpdate,
×
4665
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4666
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4667
                                Disabled:                r.Policy1Disabled,
×
4668
                                MessageFlags:            r.Policy1MessageFlags,
×
4669
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4670
                                Signature:               r.Policy1Signature,
×
4671
                        }
×
4672
                }
×
4673
                if r.Policy2ID.Valid {
×
4674
                        policy2 = &sqlc.GraphChannelPolicy{
×
4675
                                ID:                      r.Policy2ID.Int64,
×
4676
                                Version:                 r.Policy2Version.Int16,
×
4677
                                ChannelID:               r.GraphChannel.ID,
×
4678
                                NodeID:                  r.Policy2NodeID.Int64,
×
4679
                                Timelock:                r.Policy2Timelock.Int32,
×
4680
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4681
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4682
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4683
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4684
                                LastUpdate:              r.Policy2LastUpdate,
×
4685
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4686
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4687
                                Disabled:                r.Policy2Disabled,
×
4688
                                MessageFlags:            r.Policy2MessageFlags,
×
4689
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4690
                                Signature:               r.Policy2Signature,
×
4691
                        }
×
4692
                }
×
4693

4694
                return policy1, policy2, nil
×
4695

4696
        case sqlc.ListChannelsByNodeIDRow:
×
4697
                if r.Policy1ID.Valid {
×
4698
                        policy1 = &sqlc.GraphChannelPolicy{
×
4699
                                ID:                      r.Policy1ID.Int64,
×
4700
                                Version:                 r.Policy1Version.Int16,
×
4701
                                ChannelID:               r.GraphChannel.ID,
×
4702
                                NodeID:                  r.Policy1NodeID.Int64,
×
4703
                                Timelock:                r.Policy1Timelock.Int32,
×
4704
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4705
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4706
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4707
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4708
                                LastUpdate:              r.Policy1LastUpdate,
×
4709
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4710
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4711
                                Disabled:                r.Policy1Disabled,
×
4712
                                MessageFlags:            r.Policy1MessageFlags,
×
4713
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4714
                                Signature:               r.Policy1Signature,
×
4715
                        }
×
4716
                }
×
4717
                if r.Policy2ID.Valid {
×
4718
                        policy2 = &sqlc.GraphChannelPolicy{
×
4719
                                ID:                      r.Policy2ID.Int64,
×
4720
                                Version:                 r.Policy2Version.Int16,
×
4721
                                ChannelID:               r.GraphChannel.ID,
×
4722
                                NodeID:                  r.Policy2NodeID.Int64,
×
4723
                                Timelock:                r.Policy2Timelock.Int32,
×
4724
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4725
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4726
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4727
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4728
                                LastUpdate:              r.Policy2LastUpdate,
×
4729
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4730
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4731
                                Disabled:                r.Policy2Disabled,
×
4732
                                MessageFlags:            r.Policy2MessageFlags,
×
4733
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4734
                                Signature:               r.Policy2Signature,
×
4735
                        }
×
4736
                }
×
4737

4738
                return policy1, policy2, nil
×
4739

4740
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4741
                if r.Policy1ID.Valid {
×
4742
                        policy1 = &sqlc.GraphChannelPolicy{
×
4743
                                ID:                      r.Policy1ID.Int64,
×
4744
                                Version:                 r.Policy1Version.Int16,
×
4745
                                ChannelID:               r.GraphChannel.ID,
×
4746
                                NodeID:                  r.Policy1NodeID.Int64,
×
4747
                                Timelock:                r.Policy1Timelock.Int32,
×
4748
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4749
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4750
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4751
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4752
                                LastUpdate:              r.Policy1LastUpdate,
×
4753
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4754
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4755
                                Disabled:                r.Policy1Disabled,
×
4756
                                MessageFlags:            r.Policy1MessageFlags,
×
4757
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4758
                                Signature:               r.Policy1Signature,
×
4759
                        }
×
4760
                }
×
4761
                if r.Policy2ID.Valid {
×
4762
                        policy2 = &sqlc.GraphChannelPolicy{
×
4763
                                ID:                      r.Policy2ID.Int64,
×
4764
                                Version:                 r.Policy2Version.Int16,
×
4765
                                ChannelID:               r.GraphChannel.ID,
×
4766
                                NodeID:                  r.Policy2NodeID.Int64,
×
4767
                                Timelock:                r.Policy2Timelock.Int32,
×
4768
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4769
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4770
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4771
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4772
                                LastUpdate:              r.Policy2LastUpdate,
×
4773
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4774
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4775
                                Disabled:                r.Policy2Disabled,
×
4776
                                MessageFlags:            r.Policy2MessageFlags,
×
4777
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4778
                                Signature:               r.Policy2Signature,
×
4779
                        }
×
4780
                }
×
4781

4782
                return policy1, policy2, nil
×
4783

4784
        case sqlc.GetChannelsByIDsRow:
×
4785
                if r.Policy1ID.Valid {
×
4786
                        policy1 = &sqlc.GraphChannelPolicy{
×
4787
                                ID:                      r.Policy1ID.Int64,
×
4788
                                Version:                 r.Policy1Version.Int16,
×
4789
                                ChannelID:               r.GraphChannel.ID,
×
4790
                                NodeID:                  r.Policy1NodeID.Int64,
×
4791
                                Timelock:                r.Policy1Timelock.Int32,
×
4792
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4793
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4794
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4795
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4796
                                LastUpdate:              r.Policy1LastUpdate,
×
4797
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4798
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4799
                                Disabled:                r.Policy1Disabled,
×
4800
                                MessageFlags:            r.Policy1MessageFlags,
×
4801
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4802
                                Signature:               r.Policy1Signature,
×
4803
                        }
×
4804
                }
×
4805
                if r.Policy2ID.Valid {
×
4806
                        policy2 = &sqlc.GraphChannelPolicy{
×
4807
                                ID:                      r.Policy2ID.Int64,
×
4808
                                Version:                 r.Policy2Version.Int16,
×
4809
                                ChannelID:               r.GraphChannel.ID,
×
4810
                                NodeID:                  r.Policy2NodeID.Int64,
×
4811
                                Timelock:                r.Policy2Timelock.Int32,
×
4812
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4813
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4814
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4815
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4816
                                LastUpdate:              r.Policy2LastUpdate,
×
4817
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4818
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4819
                                Disabled:                r.Policy2Disabled,
×
4820
                                MessageFlags:            r.Policy2MessageFlags,
×
4821
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4822
                                Signature:               r.Policy2Signature,
×
4823
                        }
×
4824
                }
×
4825

4826
                return policy1, policy2, nil
×
4827

4828
        default:
×
4829
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4830
                        "extractChannelPolicies: %T", r)
×
4831
        }
4832
}
4833

4834
// channelIDToBytes converts a channel ID (SCID) to a byte array
4835
// representation.
4836
func channelIDToBytes(channelID uint64) []byte {
×
4837
        var chanIDB [8]byte
×
4838
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4839

×
4840
        return chanIDB[:]
×
4841
}
×
4842

4843
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4844
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4845
        if len(addresses) == 0 {
×
4846
                return nil, nil
×
4847
        }
×
4848

4849
        result := make([]net.Addr, 0, len(addresses))
×
4850
        for _, addr := range addresses {
×
4851
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
4852
                if err != nil {
×
4853
                        return nil, fmt.Errorf("unable to parse address %s "+
×
4854
                                "of type %d: %w", addr.address, addr.addrType,
×
4855
                                err)
×
4856
                }
×
4857
                if netAddr != nil {
×
4858
                        result = append(result, netAddr)
×
4859
                }
×
4860
        }
4861

4862
        // If we have no valid addresses, return nil instead of empty slice.
4863
        if len(result) == 0 {
×
4864
                return nil, nil
×
4865
        }
×
4866

4867
        return result, nil
×
4868
}
4869

4870
// parseAddress parses the given address string based on the address type
4871
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4872
// and opaque addresses.
4873
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4874
        switch addrType {
×
4875
        case addressTypeIPv4:
×
4876
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4877
                if err != nil {
×
4878
                        return nil, err
×
4879
                }
×
4880

4881
                tcp.IP = tcp.IP.To4()
×
4882

×
4883
                return tcp, nil
×
4884

4885
        case addressTypeIPv6:
×
4886
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4887
                if err != nil {
×
4888
                        return nil, err
×
4889
                }
×
4890

4891
                return tcp, nil
×
4892

4893
        case addressTypeTorV3, addressTypeTorV2:
×
4894
                service, portStr, err := net.SplitHostPort(address)
×
4895
                if err != nil {
×
4896
                        return nil, fmt.Errorf("unable to split tor "+
×
4897
                                "address: %v", address)
×
4898
                }
×
4899

4900
                port, err := strconv.Atoi(portStr)
×
4901
                if err != nil {
×
4902
                        return nil, err
×
4903
                }
×
4904

4905
                return &tor.OnionAddr{
×
4906
                        OnionService: service,
×
4907
                        Port:         port,
×
4908
                }, nil
×
4909

4910
        case addressTypeDNS:
×
4911
                hostname, portStr, err := net.SplitHostPort(address)
×
4912
                if err != nil {
×
4913
                        return nil, fmt.Errorf("unable to split DNS "+
×
4914
                                "address: %v", address)
×
4915
                }
×
4916

4917
                port, err := strconv.Atoi(portStr)
×
4918
                if err != nil {
×
4919
                        return nil, err
×
4920
                }
×
4921

4922
                return &lnwire.DNSAddress{
×
4923
                        Hostname: hostname,
×
4924
                        Port:     uint16(port),
×
4925
                }, nil
×
4926

4927
        case addressTypeOpaque:
×
4928
                opaque, err := hex.DecodeString(address)
×
4929
                if err != nil {
×
4930
                        return nil, fmt.Errorf("unable to decode opaque "+
×
4931
                                "address: %v", address)
×
4932
                }
×
4933

4934
                return &lnwire.OpaqueAddrs{
×
4935
                        Payload: opaque,
×
4936
                }, nil
×
4937

4938
        default:
×
4939
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4940
        }
4941
}
4942

4943
// batchNodeData holds all the related data for a batch of nodes.
4944
type batchNodeData struct {
4945
        // features is a map from a DB node ID to the feature bits for that
4946
        // node.
4947
        features map[int64][]int
4948

4949
        // addresses is a map from a DB node ID to the node's addresses.
4950
        addresses map[int64][]nodeAddress
4951

4952
        // extraFields is a map from a DB node ID to the extra signed fields
4953
        // for that node.
4954
        extraFields map[int64]map[uint64][]byte
4955
}
4956

4957
// nodeAddress holds the address type, position and address string for a
4958
// node. This is used to batch the fetching of node addresses.
4959
type nodeAddress struct {
4960
        addrType dbAddressType
4961
        position int32
4962
        address  string
4963
}
4964

4965
// batchLoadNodeData loads all related data for a batch of node IDs using the
4966
// provided SQLQueries interface. It returns a batchNodeData instance containing
4967
// the node features, addresses and extra signed fields.
4968
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4969
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
4970

×
4971
        // Batch load the node features.
×
4972
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4973
        if err != nil {
×
4974
                return nil, fmt.Errorf("unable to batch load node "+
×
4975
                        "features: %w", err)
×
4976
        }
×
4977

4978
        // Batch load the node addresses.
4979
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4980
        if err != nil {
×
4981
                return nil, fmt.Errorf("unable to batch load node "+
×
4982
                        "addresses: %w", err)
×
4983
        }
×
4984

4985
        // Batch load the node extra signed fields.
4986
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
4987
        if err != nil {
×
4988
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4989
                        "signed fields: %w", err)
×
4990
        }
×
4991

4992
        return &batchNodeData{
×
4993
                features:    features,
×
4994
                addresses:   addrs,
×
4995
                extraFields: extraTypes,
×
4996
        }, nil
×
4997
}
4998

4999
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5000
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5001
func batchLoadNodeFeaturesHelper(ctx context.Context,
5002
        cfg *sqldb.QueryConfig, db SQLQueries,
5003
        nodeIDs []int64) (map[int64][]int, error) {
×
5004

×
5005
        features := make(map[int64][]int)
×
5006

×
5007
        return features, sqldb.ExecuteBatchQuery(
×
5008
                ctx, cfg, nodeIDs,
×
5009
                func(id int64) int64 {
×
5010
                        return id
×
5011
                },
×
5012
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
5013
                        error) {
×
5014

×
5015
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
5016
                },
×
5017
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
5018
                        features[feature.NodeID] = append(
×
5019
                                features[feature.NodeID],
×
5020
                                int(feature.FeatureBit),
×
5021
                        )
×
5022

×
5023
                        return nil
×
5024
                },
×
5025
        )
5026
}
5027

5028
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5029
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5030
// node ID to a slice of nodeAddress structs.
5031
func batchLoadNodeAddressesHelper(ctx context.Context,
5032
        cfg *sqldb.QueryConfig, db SQLQueries,
5033
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5034

×
5035
        addrs := make(map[int64][]nodeAddress)
×
5036

×
5037
        return addrs, sqldb.ExecuteBatchQuery(
×
5038
                ctx, cfg, nodeIDs,
×
5039
                func(id int64) int64 {
×
5040
                        return id
×
5041
                },
×
5042
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5043
                        error) {
×
5044

×
5045
                        return db.GetNodeAddressesBatch(ctx, ids)
×
5046
                },
×
5047
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
5048
                        addrs[addr.NodeID] = append(
×
5049
                                addrs[addr.NodeID], nodeAddress{
×
5050
                                        addrType: dbAddressType(addr.Type),
×
5051
                                        position: addr.Position,
×
5052
                                        address:  addr.Address,
×
5053
                                },
×
5054
                        )
×
5055

×
5056
                        return nil
×
5057
                },
×
5058
        )
5059
}
5060

5061
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5062
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5063
// query.
5064
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5065
        cfg *sqldb.QueryConfig, db SQLQueries,
5066
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5067

×
5068
        extraFields := make(map[int64]map[uint64][]byte)
×
5069

×
5070
        callback := func(ctx context.Context,
×
5071
                field sqlc.GraphNodeExtraType) error {
×
5072

×
5073
                if extraFields[field.NodeID] == nil {
×
5074
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5075
                }
×
5076
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
5077

×
5078
                return nil
×
5079
        }
5080

5081
        return extraFields, sqldb.ExecuteBatchQuery(
×
5082
                ctx, cfg, nodeIDs,
×
5083
                func(id int64) int64 {
×
5084
                        return id
×
5085
                },
×
5086
                func(ctx context.Context, ids []int64) (
5087
                        []sqlc.GraphNodeExtraType, error) {
×
5088

×
5089
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5090
                },
×
5091
                callback,
5092
        )
5093
}
5094

5095
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5096
// from the provided sqlc.GraphChannelPolicy records and the
5097
// provided batchChannelData.
5098
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5099
        channelID uint64, node1, node2 route.Vertex,
5100
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5101
        *models.ChannelEdgePolicy, error) {
×
5102

×
5103
        pol1, err := buildChanPolicyWithBatchData(
×
5104
                dbPol1, channelID, node2, batchData,
×
5105
        )
×
5106
        if err != nil {
×
5107
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
5108
        }
×
5109

5110
        pol2, err := buildChanPolicyWithBatchData(
×
5111
                dbPol2, channelID, node1, batchData,
×
5112
        )
×
5113
        if err != nil {
×
5114
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
5115
        }
×
5116

5117
        return pol1, pol2, nil
×
5118
}
5119

5120
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5121
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5122
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5123
        channelID uint64, toNode route.Vertex,
5124
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
5125

×
5126
        if dbPol == nil {
×
5127
                return nil, nil
×
5128
        }
×
5129

5130
        var dbPol1Extras map[uint64][]byte
×
5131
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
5132
                dbPol1Extras = extras
×
5133
        } else {
×
5134
                dbPol1Extras = make(map[uint64][]byte)
×
5135
        }
×
5136

5137
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5138
}
5139

5140
// batchChannelData holds all the related data for a batch of channels and
// their policies.
type batchChannelData struct {
	// chanfeatures is a map from DB channel ID to a slice of feature bits.
	chanfeatures map[int64][]int

	// chanExtraTypes is a map from DB channel ID to a map of TLV type to
	// extra signed field bytes.
	chanExtraTypes map[int64]map[uint64][]byte

	// policyExtras is a map from DB channel policy ID to a map of TLV type
	// to extra signed field bytes.
	policyExtras map[int64]map[uint64][]byte
}
5153

5154
// batchLoadChannelData loads all related data for batches of channels and
5155
// policies.
5156
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5157
        db SQLQueries, channelIDs []int64,
5158
        policyIDs []int64) (*batchChannelData, error) {
×
5159

×
5160
        batchData := &batchChannelData{
×
5161
                chanfeatures:   make(map[int64][]int),
×
5162
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5163
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5164
        }
×
5165

×
5166
        // Batch load channel features and extras
×
5167
        var err error
×
5168
        if len(channelIDs) > 0 {
×
5169
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
5170
                        ctx, cfg, db, channelIDs,
×
5171
                )
×
5172
                if err != nil {
×
5173
                        return nil, fmt.Errorf("unable to batch load "+
×
5174
                                "channel features: %w", err)
×
5175
                }
×
5176

5177
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5178
                        ctx, cfg, db, channelIDs,
×
5179
                )
×
5180
                if err != nil {
×
5181
                        return nil, fmt.Errorf("unable to batch load "+
×
5182
                                "channel extras: %w", err)
×
5183
                }
×
5184
        }
5185

5186
        if len(policyIDs) > 0 {
×
5187
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
5188
                        ctx, cfg, db, policyIDs,
×
5189
                )
×
5190
                if err != nil {
×
5191
                        return nil, fmt.Errorf("unable to batch load "+
×
5192
                                "policy extras: %w", err)
×
5193
                }
×
5194
                batchData.policyExtras = policyExtras
×
5195
        }
5196

5197
        return batchData, nil
×
5198
}
5199

5200
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5201
// channel IDs using ExecuteBatchQuery wrapper around the
5202
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5203
// slice of feature bits.
5204
func batchLoadChannelFeaturesHelper(ctx context.Context,
5205
        cfg *sqldb.QueryConfig, db SQLQueries,
5206
        channelIDs []int64) (map[int64][]int, error) {
×
5207

×
5208
        features := make(map[int64][]int)
×
5209

×
5210
        return features, sqldb.ExecuteBatchQuery(
×
5211
                ctx, cfg, channelIDs,
×
5212
                func(id int64) int64 {
×
5213
                        return id
×
5214
                },
×
5215
                func(ctx context.Context,
5216
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
5217

×
5218
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
5219
                },
×
5220
                func(ctx context.Context,
5221
                        feature sqlc.GraphChannelFeature) error {
×
5222

×
5223
                        features[feature.ChannelID] = append(
×
5224
                                features[feature.ChannelID],
×
5225
                                int(feature.FeatureBit),
×
5226
                        )
×
5227

×
5228
                        return nil
×
5229
                },
×
5230
        )
5231
}
5232

5233
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5234
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5235
// query. It returns a map from DB channel ID to a map of TLV type to extra
5236
// signed field bytes.
5237
func batchLoadChannelExtrasHelper(ctx context.Context,
5238
        cfg *sqldb.QueryConfig, db SQLQueries,
5239
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5240

×
5241
        extras := make(map[int64]map[uint64][]byte)
×
5242

×
5243
        cb := func(ctx context.Context,
×
5244
                extra sqlc.GraphChannelExtraType) error {
×
5245

×
5246
                if extras[extra.ChannelID] == nil {
×
5247
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5248
                }
×
5249
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
5250

×
5251
                return nil
×
5252
        }
5253

5254
        return extras, sqldb.ExecuteBatchQuery(
×
5255
                ctx, cfg, channelIDs,
×
5256
                func(id int64) int64 {
×
5257
                        return id
×
5258
                },
×
5259
                func(ctx context.Context,
5260
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5261

×
5262
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5263
                }, cb,
×
5264
        )
5265
}
5266

5267
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5268
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5269
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5270
// a map of TLV type to extra signed field bytes.
5271
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5272
        cfg *sqldb.QueryConfig, db SQLQueries,
5273
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5274

×
5275
        extras := make(map[int64]map[uint64][]byte)
×
5276

×
5277
        return extras, sqldb.ExecuteBatchQuery(
×
5278
                ctx, cfg, policyIDs,
×
5279
                func(id int64) int64 {
×
5280
                        return id
×
5281
                },
×
5282
                func(ctx context.Context, ids []int64) (
5283
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
5284

×
5285
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
5286
                },
×
5287
                func(ctx context.Context,
5288
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
5289

×
5290
                        if extras[row.PolicyID] == nil {
×
5291
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5292
                        }
×
5293
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5294

×
5295
                        return nil
×
5296
                },
5297
        )
5298
}
5299

5300
// forEachNodePaginated executes a paginated query to process each node in the
5301
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5302
// and applies the provided processNode function to each node.
5303
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5304
        db SQLQueries, protocol lnwire.GossipVersion,
5305
        processNode func(context.Context, int64,
5306
                *models.Node) error) error {
×
5307

×
5308
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5309
                limit int32) ([]sqlc.GraphNode, error) {
×
5310

×
5311
                return db.ListNodesPaginated(
×
5312
                        ctx, sqlc.ListNodesPaginatedParams{
×
5313
                                Version: int16(protocol),
×
5314
                                ID:      lastID,
×
5315
                                Limit:   limit,
×
5316
                        },
×
5317
                )
×
5318
        }
×
5319

5320
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5321
                return node.ID
×
5322
        }
×
5323

5324
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5325
                return node.ID, nil
×
5326
        }
×
5327

5328
        batchQueryFunc := func(ctx context.Context,
×
5329
                nodeIDs []int64) (*batchNodeData, error) {
×
5330

×
5331
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5332
        }
×
5333

5334
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5335
                batchData *batchNodeData) error {
×
5336

×
5337
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
5338
                if err != nil {
×
5339
                        return fmt.Errorf("unable to build "+
×
5340
                                "node(id=%d): %w", dbNode.ID, err)
×
5341
                }
×
5342

5343
                return processNode(ctx, dbNode.ID, node)
×
5344
        }
5345

5346
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5347
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5348
                collectFunc, batchQueryFunc, processItem,
×
5349
        )
×
5350
}
5351

5352
// forEachChannelWithPolicies executes a paginated query to process each channel
5353
// with policies in the graph.
5354
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5355
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5356
                *models.ChannelEdgePolicy,
5357
                *models.ChannelEdgePolicy) error) error {
×
5358

×
5359
        type channelBatchIDs struct {
×
5360
                channelID int64
×
5361
                policyIDs []int64
×
5362
        }
×
5363

×
5364
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5365
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5366
                error) {
×
5367

×
5368
                return db.ListChannelsWithPoliciesPaginated(
×
5369
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
NEW
5370
                                Version: int16(lnwire.GossipVersion1),
×
5371
                                ID:      lastID,
×
5372
                                Limit:   limit,
×
5373
                        },
×
5374
                )
×
5375
        }
×
5376

5377
        extractPageCursor := func(
×
5378
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5379

×
5380
                return row.GraphChannel.ID
×
5381
        }
×
5382

5383
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5384
                channelBatchIDs, error) {
×
5385

×
5386
                ids := channelBatchIDs{
×
5387
                        channelID: row.GraphChannel.ID,
×
5388
                }
×
5389

×
5390
                // Extract policy IDs from the row.
×
5391
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5392
                if err != nil {
×
5393
                        return ids, err
×
5394
                }
×
5395

5396
                if dbPol1 != nil {
×
5397
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5398
                }
×
5399
                if dbPol2 != nil {
×
5400
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5401
                }
×
5402

5403
                return ids, nil
×
5404
        }
5405

5406
        batchDataFunc := func(ctx context.Context,
×
5407
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5408

×
5409
                // Separate channel IDs from policy IDs.
×
5410
                var (
×
5411
                        channelIDs = make([]int64, len(allIDs))
×
5412
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5413
                )
×
5414

×
5415
                for i, ids := range allIDs {
×
5416
                        channelIDs[i] = ids.channelID
×
5417
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5418
                }
×
5419

5420
                return batchLoadChannelData(
×
5421
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5422
                )
×
5423
        }
5424

5425
        processItem := func(ctx context.Context,
×
5426
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5427
                batchData *batchChannelData) error {
×
5428

×
5429
                node1, node2, err := buildNodeVertices(
×
5430
                        row.Node1Pubkey, row.Node2Pubkey,
×
5431
                )
×
5432
                if err != nil {
×
5433
                        return err
×
5434
                }
×
5435

5436
                edge, err := buildEdgeInfoWithBatchData(
×
5437
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5438
                        batchData,
×
5439
                )
×
5440
                if err != nil {
×
5441
                        return fmt.Errorf("unable to build channel info: %w",
×
5442
                                err)
×
5443
                }
×
5444

5445
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5446
                if err != nil {
×
5447
                        return err
×
5448
                }
×
5449

5450
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5451
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
5452
                )
×
5453
                if err != nil {
×
5454
                        return err
×
5455
                }
×
5456

5457
                return processChannel(edge, p1, p2)
×
5458
        }
5459

5460
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5461
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5462
                collectFunc, batchDataFunc, processItem,
×
5463
        )
×
5464
}
5465

5466
// buildDirectedChannel builds a DirectedChannel instance from the provided
5467
// data.
5468
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5469
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5470
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5471
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
5472

×
5473
        node1, node2, err := buildNodeVertices(
×
5474
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5475
        )
×
5476
        if err != nil {
×
5477
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
5478
        }
×
5479

5480
        edge, err := buildEdgeInfoWithBatchData(
×
5481
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5482
        )
×
5483
        if err != nil {
×
5484
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5485
        }
×
5486

5487
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
5488
        if err != nil {
×
5489
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5490
                        err)
×
5491
        }
×
5492

5493
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5494
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5495
                channelBatchData,
×
5496
        )
×
5497
        if err != nil {
×
5498
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5499
                        err)
×
5500
        }
×
5501

5502
        // Determine outgoing and incoming policy for this specific node.
5503
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5504
        p2ToNode := channelRow.GraphChannel.NodeID1
×
5505
        outPolicy, inPolicy := p1, p2
×
5506
        if (p1 != nil && p1ToNode == nodeID) ||
×
5507
                (p2 != nil && p2ToNode != nodeID) {
×
5508

×
5509
                outPolicy, inPolicy = p2, p1
×
5510
        }
×
5511

5512
        // Build cached policy.
5513
        var cachedInPolicy *models.CachedEdgePolicy
×
5514
        if inPolicy != nil {
×
5515
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5516
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5517
                cachedInPolicy.ToNodeFeatures = features
×
5518
        }
×
5519

5520
        // Extract inbound fee.
5521
        var inboundFee lnwire.Fee
×
5522
        if outPolicy != nil {
×
5523
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5524
                        inboundFee = fee
×
5525
                })
×
5526
        }
5527

5528
        // Build directed channel.
5529
        directedChannel := &DirectedChannel{
×
5530
                ChannelID:    edge.ChannelID,
×
5531
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
5532
                OtherNode:    edge.NodeKey2Bytes,
×
5533
                Capacity:     edge.Capacity,
×
5534
                OutPolicySet: outPolicy != nil,
×
5535
                InPolicy:     cachedInPolicy,
×
5536
                InboundFee:   inboundFee,
×
5537
        }
×
5538

×
5539
        if nodePub == edge.NodeKey2Bytes {
×
5540
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5541
        }
×
5542

5543
        return directedChannel, nil
×
5544
}
5545

5546
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5547
// provided rows. It uses batch loading for channels, policies, and nodes.
5548
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5549
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5550

×
5551
        var (
×
5552
                channelIDs = make([]int64, len(rows))
×
5553
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5554
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5555

×
5556
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5557
                nodeIDSet = make(map[int64]bool)
×
5558

×
5559
                // edges will hold the final channel edges built from the rows.
×
5560
                edges = make([]ChannelEdge, 0, len(rows))
×
5561
        )
×
5562

×
5563
        // Collect all IDs needed for batch loading.
×
5564
        for i, row := range rows {
×
5565
                channelIDs[i] = row.Channel().ID
×
5566

×
5567
                // Collect policy IDs
×
5568
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5569
                if err != nil {
×
5570
                        return nil, fmt.Errorf("unable to extract channel "+
×
5571
                                "policies: %w", err)
×
5572
                }
×
5573
                if dbPol1 != nil {
×
5574
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5575
                }
×
5576
                if dbPol2 != nil {
×
5577
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5578
                }
×
5579

5580
                var (
×
5581
                        node1ID = row.Node1().ID
×
5582
                        node2ID = row.Node2().ID
×
5583
                )
×
5584

×
5585
                // Collect unique node IDs.
×
5586
                if !nodeIDSet[node1ID] {
×
5587
                        nodeIDs = append(nodeIDs, node1ID)
×
5588
                        nodeIDSet[node1ID] = true
×
5589
                }
×
5590

5591
                if !nodeIDSet[node2ID] {
×
5592
                        nodeIDs = append(nodeIDs, node2ID)
×
5593
                        nodeIDSet[node2ID] = true
×
5594
                }
×
5595
        }
5596

5597
        // Batch the data for all the channels and policies.
5598
        channelBatchData, err := batchLoadChannelData(
×
5599
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5600
        )
×
5601
        if err != nil {
×
5602
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5603
                        "policy data: %w", err)
×
5604
        }
×
5605

5606
        // Batch the data for all the nodes.
5607
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5608
        if err != nil {
×
5609
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5610
                        err)
×
5611
        }
×
5612

5613
        // Build all channel edges using batch data.
5614
        for _, row := range rows {
×
5615
                // Build nodes using batch data.
×
5616
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5617
                if err != nil {
×
5618
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5619
                }
×
5620

5621
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
5622
                if err != nil {
×
5623
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5624
                }
×
5625

5626
                // Build channel info using batch data.
5627
                channel, err := buildEdgeInfoWithBatchData(
×
5628
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5629
                        node2.PubKeyBytes, channelBatchData,
×
5630
                )
×
5631
                if err != nil {
×
5632
                        return nil, fmt.Errorf("unable to build channel "+
×
5633
                                "info: %w", err)
×
5634
                }
×
5635

5636
                // Extract and build policies using batch data.
5637
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5638
                if err != nil {
×
5639
                        return nil, fmt.Errorf("unable to extract channel "+
×
5640
                                "policies: %w", err)
×
5641
                }
×
5642

5643
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5644
                        dbPol1, dbPol2, channel.ChannelID,
×
5645
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
5646
                )
×
5647
                if err != nil {
×
5648
                        return nil, fmt.Errorf("unable to build channel "+
×
5649
                                "policies: %w", err)
×
5650
                }
×
5651

5652
                edges = append(edges, ChannelEdge{
×
5653
                        Info:    channel,
×
5654
                        Policy1: p1,
×
5655
                        Policy2: p2,
×
5656
                        Node1:   node1,
×
5657
                        Node2:   node2,
×
5658
                })
×
5659
        }
5660

5661
        return edges, nil
×
5662
}
5663

5664
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5665
// instances from the provided rows using batch loading for channel data.
5666
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5667
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5668
        []*models.ChannelEdgeInfo, []int64, error) {
×
5669

×
5670
        if len(rows) == 0 {
×
5671
                return nil, nil, nil
×
5672
        }
×
5673

5674
        // Collect all the channel IDs needed for batch loading.
5675
        channelIDs := make([]int64, len(rows))
×
5676
        for i, row := range rows {
×
5677
                channelIDs[i] = row.Channel().ID
×
5678
        }
×
5679

5680
        // Batch load the channel data.
5681
        channelBatchData, err := batchLoadChannelData(
×
5682
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5683
        )
×
5684
        if err != nil {
×
5685
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5686
                        "data: %w", err)
×
5687
        }
×
5688

5689
        // Build all channel edges using batch data.
5690
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5691
        for _, row := range rows {
×
5692
                node1, node2, err := buildNodeVertices(
×
5693
                        row.Node1Pub(), row.Node2Pub(),
×
5694
                )
×
5695
                if err != nil {
×
5696
                        return nil, nil, err
×
5697
                }
×
5698

5699
                // Build channel info using batch data
5700
                info, err := buildEdgeInfoWithBatchData(
×
5701
                        cfg.ChainHash, row.Channel(), node1, node2,
×
5702
                        channelBatchData,
×
5703
                )
×
5704
                if err != nil {
×
5705
                        return nil, nil, err
×
5706
                }
×
5707

5708
                edges = append(edges, info)
×
5709
        }
5710

5711
        return edges, channelIDs, nil
×
5712
}
5713

5714
// handleZombieMarking is a helper function that handles the logic of
5715
// marking a channel as a zombie in the database. It takes into account whether
5716
// we are in strict zombie pruning mode, and adjusts the node public keys
5717
// accordingly based on the last update timestamps of the channel policies.
5718
func handleZombieMarking(ctx context.Context, db SQLQueries,
5719
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
5720
        strictZombiePruning bool, scid uint64) error {
×
5721

×
5722
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5723

×
5724
        if strictZombiePruning {
×
5725
                var e1UpdateTime, e2UpdateTime *time.Time
×
5726
                if row.Policy1LastUpdate.Valid {
×
5727
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5728
                        e1UpdateTime = &e1Time
×
5729
                }
×
5730
                if row.Policy2LastUpdate.Valid {
×
5731
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5732
                        e2UpdateTime = &e2Time
×
5733
                }
×
5734

5735
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
5736
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
5737
                        e2UpdateTime,
×
5738
                )
×
5739
        }
5740

5741
        return db.UpsertZombieChannel(
×
5742
                ctx, sqlc.UpsertZombieChannelParams{
×
NEW
5743
                        Version:  int16(lnwire.GossipVersion1),
×
5744
                        Scid:     channelIDToBytes(scid),
×
5745
                        NodeKey1: nodeKey1[:],
×
5746
                        NodeKey2: nodeKey2[:],
×
5747
                },
×
5748
        )
×
5749
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc