• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lightningnetwork / lnd / 19398014578

16 Nov 2025 12:53AM UTC coverage: 65.213% (+8.4%) from 56.858%
19398014578

Pull #10323

github

web-flow
Merge c98812e55 into 841a29118
Pull Request #10323: walletrpc: add raw_tx field to BumpFee response

29 of 50 new or added lines in 2 files covered. (58.0%)

6215 existing lines in 35 files now uncovered.

137591 of 210988 relevant lines covered (65.21%)

20814.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/graph/db/sql_store.go
1
package graphdb
2

3
import (
4
        "bytes"
5
        "context"
6
        "database/sql"
7
        "encoding/hex"
8
        "errors"
9
        "fmt"
10
        color "image/color"
11
        "iter"
12
        "maps"
13
        "math"
14
        "net"
15
        "slices"
16
        "strconv"
17
        "sync"
18
        "time"
19

20
        "github.com/btcsuite/btcd/btcec/v2"
21
        "github.com/btcsuite/btcd/btcutil"
22
        "github.com/btcsuite/btcd/chaincfg/chainhash"
23
        "github.com/btcsuite/btcd/wire"
24
        "github.com/lightningnetwork/lnd/aliasmgr"
25
        "github.com/lightningnetwork/lnd/batch"
26
        "github.com/lightningnetwork/lnd/fn/v2"
27
        "github.com/lightningnetwork/lnd/graph/db/models"
28
        "github.com/lightningnetwork/lnd/lnwire"
29
        "github.com/lightningnetwork/lnd/routing/route"
30
        "github.com/lightningnetwork/lnd/sqldb"
31
        "github.com/lightningnetwork/lnd/sqldb/sqlc"
32
        "github.com/lightningnetwork/lnd/tlv"
33
        "github.com/lightningnetwork/lnd/tor"
34
)
35

36
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
37
// execute queries against the SQL graph tables.
38
//
// The methods are grouped below by the kind of data they touch: nodes, source
// nodes, channels, channel policies, the zombie index, the prune log, closed
// SCIDs and, last, helpers that are only meant for migrations.
//
39
//nolint:ll,interfacebloat
40
type SQLQueries interface {
41
        /*
42
                Node queries.
43
        */
44
        UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
45
        GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.GraphNode, error)
46
        GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error)
47
        GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
48
        GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error)
49
        ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error)
50
        ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
51
        IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error)
52
        DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error)
53
        DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
54
        DeleteNode(ctx context.Context, id int64) error
55

56
        GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeExtraType, error)
57
        GetNodeExtraTypesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeExtraType, error)
58
        UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
59
        DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
60

61
        UpsertNodeAddress(ctx context.Context, arg sqlc.UpsertNodeAddressParams) error
62
        GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
63
        GetNodeAddressesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress, error)
64
        DeleteNodeAddresses(ctx context.Context, nodeID int64) error
65

66
        InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
67
        GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.GraphNodeFeature, error)
68
        GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature, error)
69
        GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
70
        DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
71

72
        /*
73
                Source node queries.
74
        */
75
        AddSourceNode(ctx context.Context, nodeID int64) error
76
        GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
77

78
        /*
79
                Channel queries.
80
        */
81
        CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
82
        AddV1ChannelProof(ctx context.Context, arg sqlc.AddV1ChannelProofParams) (sql.Result, error)
83
        GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.GraphChannel, error)
84
        GetChannelsBySCIDs(ctx context.Context, arg sqlc.GetChannelsBySCIDsParams) ([]sqlc.GraphChannel, error)
85
        GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
86
        GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
87
        GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
88
        GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
89
        GetChannelsByIDs(ctx context.Context, ids []int64) ([]sqlc.GetChannelsByIDsRow, error)
90
        GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
91
        HighestSCID(ctx context.Context, version int16) ([]byte, error)
92
        ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
93
        ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
94
        ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
95
        ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
96
        ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
97
        GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
98
        GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
99
        GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error)
100
        GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
101
        DeleteChannels(ctx context.Context, ids []int64) error
102

103
        UpsertChannelExtraType(ctx context.Context, arg sqlc.UpsertChannelExtraTypeParams) error
104
        GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
105
        InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
106
        GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
107

108
        /*
109
                Channel Policy table queries.
110
        */
111
        UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
112
        GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.GraphChannelPolicy, error)
113
        GetV1DisabledSCIDs(ctx context.Context) ([][]byte, error)
114

115
        UpsertChanPolicyExtraType(ctx context.Context, arg sqlc.UpsertChanPolicyExtraTypeParams) error
116
        GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
117
        DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
118

119
        /*
120
                Zombie index queries.
121
        */
122
        UpsertZombieChannel(ctx context.Context, arg sqlc.UpsertZombieChannelParams) error
123
        GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.GraphZombieChannel, error)
124
        GetZombieChannelsSCIDs(ctx context.Context, arg sqlc.GetZombieChannelsSCIDsParams) ([]sqlc.GraphZombieChannel, error)
125
        CountZombieChannels(ctx context.Context, version int16) (int64, error)
126
        DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
127
        IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
128

129
        /*
130
                Prune log table queries.
131
        */
132
        GetPruneTip(ctx context.Context) (sqlc.GraphPruneLog, error)
133
        GetPruneHashByHeight(ctx context.Context, blockHeight int64) ([]byte, error)
134
        GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]sqlc.GraphPruneLog, error)
135
        UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
136
        DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
137

138
        /*
139
                Closed SCID table queries.
140
        */
141
        InsertClosedChannel(ctx context.Context, scid []byte) error
142
        IsClosedChannel(ctx context.Context, scid []byte) (bool, error)
143
        GetClosedChannelsSCIDs(ctx context.Context, scids [][]byte) ([][]byte, error)
144

145
        /*
146
                Migration specific queries.
147

148
                NOTE: these should not be used in code other than migrations.
149
                Once sqldbv2 is in place, these can be removed from this
150
                interface, as migrations will then have their own dedicated
151
                query interfaces.
152
        */
153
        InsertNodeMig(ctx context.Context, arg sqlc.InsertNodeMigParams) (int64, error)
154
        InsertChannelMig(ctx context.Context, arg sqlc.InsertChannelMigParams) (int64, error)
155
        InsertEdgePolicyMig(ctx context.Context, arg sqlc.InsertEdgePolicyMigParams) (int64, error)
156
}
157

158
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
159
// database operations. It combines the plain query set with
// sqldb.BatchedTx[SQLQueries]; SQLStore holds a value of this type and runs
// its transactional closures through it (presumably via the embedded
// BatchedTx - confirm against the sqldb package).
160
type BatchedSQLQueries interface {
161
        SQLQueries
162
        sqldb.BatchedTx[SQLQueries]
163
}
164

165
// SQLStore is an implementation of the V1Store interface that uses a SQL
166
// database as the backend.
167
type SQLStore struct {
168
        // cfg holds the store configuration handed to NewSQLStore.
        cfg *SQLStoreConfig
169
        // db is the query backend all transactions are executed against.
        db  BatchedSQLQueries
170

171
        // cacheMu guards all caches (rejectCache and chanCache). If this
172
        // mutex needs to be held at the same time as the DB mutex, then
173
        // cacheMu MUST be acquired first to prevent deadlock.
174
        cacheMu     sync.RWMutex
175
        rejectCache *rejectCache
176
        chanCache   *channelCache
177

178
        // chanScheduler and nodeScheduler batch write requests. NewSQLStore
        // wires chanScheduler up with cacheMu, whereas nodeScheduler is
        // created without a lock.
        chanScheduler batch.Scheduler[SQLQueries]
179
        nodeScheduler batch.Scheduler[SQLQueries]
180

181
        // srcNodes holds per-gossip-version source node info.
        // NOTE(review): srcNodeMu presumably guards srcNodes - confirm
        // against the code that reads and writes the map.
        srcNodes  map[lnwire.GossipVersion]*srcNodeInfo
182
        srcNodeMu sync.Mutex
183
}
184

185
// A compile-time assertion to ensure that SQLStore implements the V1Store
186
// interface.
187
var _ V1Store = (*SQLStore)(nil)
188

189
// SQLStoreConfig holds the configuration for the SQLStore.
190
type SQLStoreConfig struct {
191
        // ChainHash is the genesis hash for the chain that all the gossip
192
        // messages in this store are aimed at.
193
        ChainHash chainhash.Hash
194

195
        // QueryCfg holds configuration values for SQL queries.
196
        QueryCfg *sqldb.QueryConfig
197
}
198

199
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
200
// storage backend.
201
func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries,
UNCOV
202
        options ...StoreOptionModifier) (*SQLStore, error) {
×
UNCOV
203

×
UNCOV
204
        opts := DefaultOptions()
×
UNCOV
205
        for _, o := range options {
×
UNCOV
206
                o(opts)
×
UNCOV
207
        }
×
208

UNCOV
209
        if opts.NoMigration {
×
UNCOV
210
                return nil, fmt.Errorf("the NoMigration option is not yet " +
×
UNCOV
211
                        "supported for SQL stores")
×
UNCOV
212
        }
×
213

UNCOV
214
        s := &SQLStore{
×
215
                cfg:         cfg,
×
216
                db:          db,
×
217
                rejectCache: newRejectCache(opts.RejectCacheSize),
×
218
                chanCache:   newChannelCache(opts.ChannelCacheSize),
×
219
                srcNodes:    make(map[lnwire.GossipVersion]*srcNodeInfo),
×
220
        }
×
UNCOV
221

×
222
        s.chanScheduler = batch.NewTimeScheduler(
×
223
                db, &s.cacheMu, opts.BatchCommitInterval,
×
224
        )
×
225
        s.nodeScheduler = batch.NewTimeScheduler(
×
UNCOV
226
                db, nil, opts.BatchCommitInterval,
×
227
        )
×
228

×
229
        return s, nil
×
230
}
231

232
// AddNode adds a vertex/node to the graph database. If the node is not
233
// in the database from before, this will add a new, unconnected one to the
234
// graph. If it is present from before, this will update that node's
235
// information.
236
//
237
// NOTE: part of the V1Store interface.
238
func (s *SQLStore) AddNode(ctx context.Context,
239
        node *models.Node, opts ...batch.SchedulerOption) error {
×
240

×
241
        r := &batch.Request[SQLQueries]{
×
242
                Opts: batch.NewSchedulerOptions(opts...),
×
UNCOV
243
                Do: func(queries SQLQueries) error {
×
UNCOV
244
                        _, err := upsertNode(ctx, queries, node)
×
UNCOV
245

×
UNCOV
246
                        // It is possible that two of the same node
×
UNCOV
247
                        // announcements are both being processed in the same
×
UNCOV
248
                        // batch. This may case the UpsertNode conflict to
×
UNCOV
249
                        // be hit since we require at the db layer that the
×
UNCOV
250
                        // new last_update is greater than the existing
×
UNCOV
251
                        // last_update. We need to gracefully handle this here.
×
252
                        if errors.Is(err, sql.ErrNoRows) {
×
253
                                return nil
×
254
                        }
×
255

256
                        return err
×
257
                },
258
        }
259

260
        return s.nodeScheduler.Execute(ctx, r)
×
261
}
262

263
// FetchNode attempts to look up a target node by its identity public
264
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
265
// returned.
266
//
267
// NOTE: part of the V1Store interface.
268
func (s *SQLStore) FetchNode(ctx context.Context,
269
        pubKey route.Vertex) (*models.Node, error) {
×
UNCOV
270

×
UNCOV
271
        var node *models.Node
×
UNCOV
272
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
273
                var err error
×
UNCOV
274
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, pubKey)
×
UNCOV
275

×
UNCOV
276
                return err
×
UNCOV
277
        }, sqldb.NoOpReset)
×
UNCOV
278
        if err != nil {
×
UNCOV
279
                return nil, fmt.Errorf("unable to fetch node: %w", err)
×
UNCOV
280
        }
×
281

282
        return node, nil
×
283
}
284

285
// HasNode determines if the graph has a vertex identified by the
286
// target node identity public key. If the node exists in the database, a
287
// timestamp of when the data for the node was lasted updated is returned along
288
// with a true boolean. Otherwise, an empty time.Time is returned with a false
289
// boolean.
290
//
291
// NOTE: part of the V1Store interface.
292
func (s *SQLStore) HasNode(ctx context.Context,
293
        pubKey [33]byte) (time.Time, bool, error) {
×
UNCOV
294

×
295
        var (
×
UNCOV
296
                exists     bool
×
UNCOV
297
                lastUpdate time.Time
×
UNCOV
298
        )
×
UNCOV
299
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
300
                dbNode, err := db.GetNodeByPubKey(
×
UNCOV
301
                        ctx, sqlc.GetNodeByPubKeyParams{
×
UNCOV
302
                                Version: int16(lnwire.GossipVersion1),
×
UNCOV
303
                                PubKey:  pubKey[:],
×
UNCOV
304
                        },
×
UNCOV
305
                )
×
306
                if errors.Is(err, sql.ErrNoRows) {
×
307
                        return nil
×
308
                } else if err != nil {
×
309
                        return fmt.Errorf("unable to fetch node: %w", err)
×
310
                }
×
311

312
                exists = true
×
313

×
314
                if dbNode.LastUpdate.Valid {
×
315
                        lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
316
                }
×
317

318
                return nil
×
319
        }, sqldb.NoOpReset)
320
        if err != nil {
×
321
                return time.Time{}, false,
×
322
                        fmt.Errorf("unable to fetch node: %w", err)
×
323
        }
×
324

325
        return lastUpdate, exists, nil
×
326
}
327

328
// AddrsForNode returns all known addresses for the target node public key
329
// that the graph DB is aware of. The returned boolean indicates if the
330
// given node is unknown to the graph DB or not.
331
//
332
// NOTE: part of the V1Store interface.
333
func (s *SQLStore) AddrsForNode(ctx context.Context,
334
        nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
×
335

×
336
        var (
×
UNCOV
337
                addresses []net.Addr
×
338
                known     bool
×
UNCOV
339
        )
×
UNCOV
340
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
341
                // First, check if the node exists and get its DB ID if it
×
UNCOV
342
                // does.
×
UNCOV
343
                dbID, err := db.GetNodeIDByPubKey(
×
UNCOV
344
                        ctx, sqlc.GetNodeIDByPubKeyParams{
×
UNCOV
345
                                Version: int16(lnwire.GossipVersion1),
×
UNCOV
346
                                PubKey:  nodePub.SerializeCompressed(),
×
347
                        },
×
348
                )
×
349
                if errors.Is(err, sql.ErrNoRows) {
×
350
                        return nil
×
351
                }
×
352

353
                known = true
×
354

×
355
                addresses, err = getNodeAddresses(ctx, db, dbID)
×
356
                if err != nil {
×
357
                        return fmt.Errorf("unable to fetch node addresses: %w",
×
358
                                err)
×
359
                }
×
360

361
                return nil
×
362
        }, sqldb.NoOpReset)
363
        if err != nil {
×
364
                return false, nil, fmt.Errorf("unable to get addresses for "+
×
UNCOV
365
                        "node(%x): %w", nodePub.SerializeCompressed(), err)
×
366
        }
×
367

368
        return known, addresses, nil
×
369
}
370

371
// DeleteNode starts a new database transaction to remove a vertex/node
372
// from the database according to the node's public key.
373
//
374
// NOTE: part of the V1Store interface.
375
func (s *SQLStore) DeleteNode(ctx context.Context,
376
        pubKey route.Vertex) error {
×
377

×
378
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
379
                res, err := db.DeleteNodeByPubKey(
×
UNCOV
380
                        ctx, sqlc.DeleteNodeByPubKeyParams{
×
381
                                Version: int16(lnwire.GossipVersion1),
×
UNCOV
382
                                PubKey:  pubKey[:],
×
UNCOV
383
                        },
×
UNCOV
384
                )
×
UNCOV
385
                if err != nil {
×
UNCOV
386
                        return err
×
UNCOV
387
                }
×
388

389
                rows, err := res.RowsAffected()
×
390
                if err != nil {
×
391
                        return err
×
392
                }
×
393

394
                if rows == 0 {
×
395
                        return ErrGraphNodeNotFound
×
396
                } else if rows > 1 {
×
397
                        return fmt.Errorf("deleted %d rows, expected 1", rows)
×
398
                }
×
399

400
                return err
×
401
        }, sqldb.NoOpReset)
402
        if err != nil {
×
403
                return fmt.Errorf("unable to delete node: %w", err)
×
404
        }
×
405

UNCOV
406
        return nil
×
407
}
408

409
// FetchNodeFeatures returns the features of the given node. If no features are
410
// known for the node, an empty feature vector is returned.
411
//
412
// NOTE: this is part of the graphdb.NodeTraverser interface.
413
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
UNCOV
414
        *lnwire.FeatureVector, error) {
×
415

×
416
        ctx := context.TODO()
×
417

×
UNCOV
418
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
419
}
×
420

421
// DisabledChannelIDs returns the channel ids of disabled channels.
422
// A channel is disabled when two of the associated ChanelEdgePolicies
423
// have their disabled bit on.
424
//
425
// NOTE: part of the V1Store interface.
UNCOV
426
func (s *SQLStore) DisabledChannelIDs() ([]uint64, error) {
×
427
        var (
×
428
                ctx     = context.TODO()
×
429
                chanIDs []uint64
×
430
        )
×
431
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
432
                dbChanIDs, err := db.GetV1DisabledSCIDs(ctx)
×
UNCOV
433
                if err != nil {
×
UNCOV
434
                        return fmt.Errorf("unable to fetch disabled "+
×
UNCOV
435
                                "channels: %w", err)
×
UNCOV
436
                }
×
437

UNCOV
438
                chanIDs = fn.Map(dbChanIDs, byteOrder.Uint64)
×
439

×
440
                return nil
×
441
        }, sqldb.NoOpReset)
442
        if err != nil {
×
443
                return nil, fmt.Errorf("unable to fetch disabled channels: %w",
×
444
                        err)
×
445
        }
×
446

447
        return chanIDs, nil
×
448
}
449

450
// LookupAlias attempts to return the alias as advertised by the target node.
451
//
452
// NOTE: part of the V1Store interface.
453
func (s *SQLStore) LookupAlias(ctx context.Context,
UNCOV
454
        pub *btcec.PublicKey) (string, error) {
×
455

×
456
        var alias string
×
457
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
458
                dbNode, err := db.GetNodeByPubKey(
×
UNCOV
459
                        ctx, sqlc.GetNodeByPubKeyParams{
×
460
                                Version: int16(lnwire.GossipVersion1),
×
UNCOV
461
                                PubKey:  pub.SerializeCompressed(),
×
UNCOV
462
                        },
×
UNCOV
463
                )
×
UNCOV
464
                if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
465
                        return ErrNodeAliasNotFound
×
UNCOV
466
                } else if err != nil {
×
467
                        return fmt.Errorf("unable to fetch node: %w", err)
×
468
                }
×
469

470
                if !dbNode.Alias.Valid {
×
471
                        return ErrNodeAliasNotFound
×
472
                }
×
473

474
                alias = dbNode.Alias.String
×
475

×
476
                return nil
×
477
        }, sqldb.NoOpReset)
478
        if err != nil {
×
479
                return "", fmt.Errorf("unable to look up alias: %w", err)
×
480
        }
×
481

UNCOV
482
        return alias, nil
×
483
}
484

485
// SourceNode returns the source node of the graph. The source node is treated
486
// as the center node within a star-graph. This method may be used to kick off
487
// a path finding algorithm in order to explore the reachability of another
488
// node based off the source node.
489
//
490
// NOTE: part of the V1Store interface.
491
func (s *SQLStore) SourceNode(ctx context.Context) (*models.Node,
492
        error) {
×
493

×
UNCOV
494
        var node *models.Node
×
495
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
496
                _, nodePub, err := s.getSourceNode(
×
UNCOV
497
                        ctx, db, lnwire.GossipVersion1,
×
UNCOV
498
                )
×
UNCOV
499
                if err != nil {
×
UNCOV
500
                        return fmt.Errorf("unable to fetch V1 source node: %w",
×
UNCOV
501
                                err)
×
UNCOV
502
                }
×
503

UNCOV
504
                _, node, err = getNodeByPubKey(ctx, s.cfg.QueryCfg, db, nodePub)
×
505

×
506
                return err
×
507
        }, sqldb.NoOpReset)
508
        if err != nil {
×
509
                return nil, fmt.Errorf("unable to fetch source node: %w", err)
×
510
        }
×
511

512
        return node, nil
×
513
}
514

515
// SetSourceNode sets the source node within the graph database. The source
516
// node is to be used as the center of a star-graph within path finding
517
// algorithms.
518
//
519
// NOTE: part of the V1Store interface.
520
func (s *SQLStore) SetSourceNode(ctx context.Context,
521
        node *models.Node) error {
×
UNCOV
522

×
523
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
UNCOV
524
                id, err := upsertNode(ctx, db, node)
×
UNCOV
525
                if err != nil {
×
UNCOV
526
                        return fmt.Errorf("unable to upsert source node: %w",
×
UNCOV
527
                                err)
×
UNCOV
528
                }
×
529

530
                // Make sure that if a source node for this version is already
531
                // set, then the ID is the same as the one we are about to set.
532
                dbSourceNodeID, _, err := s.getSourceNode(
×
533
                        ctx, db, lnwire.GossipVersion1,
×
534
                )
×
535
                if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
×
536
                        return fmt.Errorf("unable to fetch source node: %w",
×
537
                                err)
×
538
                } else if err == nil {
×
539
                        if dbSourceNodeID != id {
×
UNCOV
540
                                return fmt.Errorf("v1 source node already "+
×
UNCOV
541
                                        "set to a different node: %d vs %d",
×
UNCOV
542
                                        dbSourceNodeID, id)
×
543
                        }
×
544

545
                        return nil
×
546
                }
547

548
                return db.AddSourceNode(ctx, id)
×
549
        }, sqldb.NoOpReset)
550
}
551

552
// NodeUpdatesInHorizon returns all the known lightning node which have an
553
// update timestamp within the passed range. This method can be used by two
554
// nodes to quickly determine if they have the same set of up to date node
555
// announcements.
556
//
557
// NOTE: This is part of the V1Store interface.
558
func (s *SQLStore) NodeUpdatesInHorizon(startTime, endTime time.Time,
UNCOV
559
        opts ...IteratorOption) iter.Seq2[*models.Node, error] {
×
UNCOV
560

×
UNCOV
561
        cfg := defaultIteratorConfig()
×
UNCOV
562
        for _, opt := range opts {
×
UNCOV
563
                opt(cfg)
×
UNCOV
564
        }
×
565

UNCOV
566
        return func(yield func(*models.Node, error) bool) {
×
UNCOV
567
                var (
×
568
                        ctx            = context.TODO()
×
569
                        lastUpdateTime sql.NullInt64
×
570
                        lastPubKey     = make([]byte, 33)
×
571
                        hasMore        = true
×
572
                )
×
573

×
UNCOV
574
                // Each iteration, we'll read a batch amount of nodes, yield
×
575
                // them, then decide is we have more or not.
×
576
                for hasMore {
×
577
                        var batch []*models.Node
×
578

×
579
                        //nolint:ll
×
580
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
581
                                //nolint:ll
×
582
                                params := sqlc.GetNodesByLastUpdateRangeParams{
×
583
                                        StartTime: sqldb.SQLInt64(
×
584
                                                startTime.Unix(),
×
585
                                        ),
×
586
                                        EndTime: sqldb.SQLInt64(
×
587
                                                endTime.Unix(),
×
588
                                        ),
×
589
                                        LastUpdate: lastUpdateTime,
×
590
                                        LastPubKey: lastPubKey,
×
591
                                        OnlyPublic: sql.NullBool{
×
592
                                                Bool:  cfg.iterPublicNodes,
×
593
                                                Valid: true,
×
594
                                        },
×
595
                                        MaxResults: sqldb.SQLInt32(
×
596
                                                cfg.nodeUpdateIterBatchSize,
×
597
                                        ),
×
598
                                }
×
599
                                rows, err := db.GetNodesByLastUpdateRange(
×
600
                                        ctx, params,
×
601
                                )
×
602
                                if err != nil {
×
603
                                        return err
×
604
                                }
×
605

606
                                hasMore = len(rows) == cfg.nodeUpdateIterBatchSize
×
607

×
608
                                err = forEachNodeInBatch(
×
609
                                        ctx, s.cfg.QueryCfg, db, rows,
×
610
                                        func(_ int64, node *models.Node) error {
×
611
                                                batch = append(batch, node)
×
612

×
613
                                                // Update pagination cursors
×
UNCOV
614
                                                // based on the last processed
×
615
                                                // node.
×
616
                                                lastUpdateTime = sql.NullInt64{
×
617
                                                        Int64: node.LastUpdate.
×
618
                                                                Unix(),
×
619
                                                        Valid: true,
×
620
                                                }
×
621
                                                lastPubKey = node.PubKeyBytes[:]
×
622

×
623
                                                return nil
×
624
                                        },
×
625
                                )
626
                                if err != nil {
×
627
                                        return fmt.Errorf("unable to build "+
×
628
                                                "nodes: %w", err)
×
629
                                }
×
630

631
                                return nil
×
632
                        }, func() {
×
633
                                batch = []*models.Node{}
×
UNCOV
634
                        })
×
635

636
                        if err != nil {
×
637
                                log.Errorf("NodeUpdatesInHorizon batch "+
×
638
                                        "error: %v", err)
×
UNCOV
639

×
640
                                yield(&models.Node{}, err)
×
641

×
642
                                return
×
643
                        }
×
644

645
                        for _, node := range batch {
×
646
                                if !yield(node, nil) {
×
647
                                        return
×
648
                                }
×
649
                        }
650

651
                        // If the batch didn't yield anything, then we're done.
652
                        if len(batch) == 0 {
×
UNCOV
653
                                break
×
654
                        }
655
                }
656
        }
657
}
658

659
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
660
// undirected edge from the two target nodes are created. The information stored
661
// denotes the static attributes of the channel, such as the channelID, the keys
662
// involved in creation of the channel, and the set of features that the channel
663
// supports. The chanPoint and chanID are used to uniquely identify the edge
664
// globally within the database.
665
//
666
// NOTE: part of the V1Store interface.
667
func (s *SQLStore) AddChannelEdge(ctx context.Context,
UNCOV
668
        edge *models.ChannelEdgeInfo, opts ...batch.SchedulerOption) error {
×
UNCOV
669

×
UNCOV
670
        var alreadyExists bool
×
UNCOV
671
        r := &batch.Request[SQLQueries]{
×
UNCOV
672
                Opts: batch.NewSchedulerOptions(opts...),
×
UNCOV
673
                Reset: func() {
×
UNCOV
674
                        alreadyExists = false
×
UNCOV
675
                },
×
UNCOV
676
                Do: func(tx SQLQueries) error {
×
677
                        chanIDB := channelIDToBytes(edge.ChannelID)
×
678

×
679
                        // Make sure that the channel doesn't already exist. We
×
680
                        // do this explicitly instead of relying on catching a
×
681
                        // unique constraint error because relying on SQL to
×
682
                        // throw that error would abort the entire batch of
×
683
                        // transactions.
×
684
                        _, err := tx.GetChannelBySCID(
×
685
                                ctx, sqlc.GetChannelBySCIDParams{
×
686
                                        Scid:    chanIDB,
×
687
                                        Version: int16(lnwire.GossipVersion1),
×
688
                                },
×
689
                        )
×
690
                        if err == nil {
×
691
                                alreadyExists = true
×
692
                                return nil
×
693
                        } else if !errors.Is(err, sql.ErrNoRows) {
×
694
                                return fmt.Errorf("unable to fetch channel: %w",
×
695
                                        err)
×
696
                        }
×
697

698
                        return insertChannel(ctx, tx, edge)
×
699
                },
700
                OnCommit: func(err error) error {
×
701
                        switch {
×
702
                        case err != nil:
×
703
                                return err
×
704
                        case alreadyExists:
×
705
                                return ErrEdgeAlreadyExist
×
UNCOV
706
                        default:
×
707
                                s.rejectCache.remove(edge.ChannelID)
×
UNCOV
708
                                s.chanCache.remove(edge.ChannelID)
×
709
                                return nil
×
710
                        }
711
                },
712
        }
713

714
        return s.chanScheduler.Execute(ctx, r)
×
715
}
716

717
// HighestChanID returns the "highest" known channel ID in the channel graph.
718
// This represents the "newest" channel from the PoV of the chain. This method
719
// can be used by peers to quickly determine if their graphs are in sync.
720
//
721
// NOTE: This is part of the V1Store interface.
UNCOV
722
func (s *SQLStore) HighestChanID(ctx context.Context) (uint64, error) {
×
723
        var highestChanID uint64
×
UNCOV
724
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
725
                chanID, err := db.HighestSCID(ctx, int16(lnwire.GossipVersion1))
×
UNCOV
726
                if errors.Is(err, sql.ErrNoRows) {
×
UNCOV
727
                        return nil
×
UNCOV
728
                } else if err != nil {
×
UNCOV
729
                        return fmt.Errorf("unable to fetch highest chan ID: %w",
×
UNCOV
730
                                err)
×
731
                }
×
732

733
                highestChanID = byteOrder.Uint64(chanID)
×
734

×
735
                return nil
×
736
        }, sqldb.NoOpReset)
737
        if err != nil {
×
738
                return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
×
739
        }
×
740

UNCOV
741
        return highestChanID, nil
×
742
}
743

744
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
745
// within the database for the referenced channel. The `flags` attribute within
746
// the ChannelEdgePolicy determines which of the directed edges are being
747
// updated. If the flag is 1, then the first node's information is being
748
// updated, otherwise it's the second node's information. The node ordering is
749
// determined by the lexicographical ordering of the identity public keys of the
750
// nodes on either side of the channel.
751
//
752
// NOTE: part of the V1Store interface.
753
func (s *SQLStore) UpdateEdgePolicy(ctx context.Context,
754
        edge *models.ChannelEdgePolicy,
UNCOV
755
        opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) {
×
UNCOV
756

×
UNCOV
757
        var (
×
UNCOV
758
                isUpdate1    bool
×
UNCOV
759
                edgeNotFound bool
×
UNCOV
760
                from, to     route.Vertex
×
UNCOV
761
        )
×
UNCOV
762

×
UNCOV
763
        r := &batch.Request[SQLQueries]{
×
764
                Opts: batch.NewSchedulerOptions(opts...),
×
765
                Reset: func() {
×
766
                        isUpdate1 = false
×
767
                        edgeNotFound = false
×
768
                },
×
769
                Do: func(tx SQLQueries) error {
×
770
                        var err error
×
771
                        from, to, isUpdate1, err = updateChanEdgePolicy(
×
772
                                ctx, tx, edge,
×
773
                        )
×
774
                        // It is possible that two of the same policy
×
775
                        // announcements are both being processed in the same
×
776
                        // batch. This may case the UpsertEdgePolicy conflict to
×
777
                        // be hit since we require at the db layer that the
×
778
                        // new last_update is greater than the existing
×
779
                        // last_update. We need to gracefully handle this here.
×
780
                        if errors.Is(err, sql.ErrNoRows) {
×
781
                                return nil
×
782
                        } else if err != nil {
×
783
                                log.Errorf("UpdateEdgePolicy faild: %v", err)
×
784
                        }
×
785

786
                        // Silence ErrEdgeNotFound so that the batch can
787
                        // succeed, but propagate the error via local state.
788
                        if errors.Is(err, ErrEdgeNotFound) {
×
789
                                edgeNotFound = true
×
790
                                return nil
×
791
                        }
×
792

793
                        return err
×
794
                },
UNCOV
795
                OnCommit: func(err error) error {
×
UNCOV
796
                        switch {
×
797
                        case err != nil:
×
798
                                return err
×
799
                        case edgeNotFound:
×
800
                                return ErrEdgeNotFound
×
UNCOV
801
                        default:
×
802
                                s.updateEdgeCache(edge, isUpdate1)
×
UNCOV
803
                                return nil
×
804
                        }
805
                },
806
        }
807

808
        err := s.chanScheduler.Execute(ctx, r)
×
809

×
810
        return from, to, err
×
811
}
812

813
// updateEdgeCache updates our reject and channel caches with the new
814
// edge policy information.
815
func (s *SQLStore) updateEdgeCache(e *models.ChannelEdgePolicy,
UNCOV
816
        isUpdate1 bool) {
×
817

×
818
        // If an entry for this channel is found in reject cache, we'll modify
×
819
        // the entry with the updated timestamp for the direction that was just
×
UNCOV
820
        // written. If the edge doesn't exist, we'll load the cache entry lazily
×
UNCOV
821
        // during the next query for this edge.
×
UNCOV
822
        if entry, ok := s.rejectCache.get(e.ChannelID); ok {
×
UNCOV
823
                if isUpdate1 {
×
UNCOV
824
                        entry.upd1Time = e.LastUpdate.Unix()
×
825
                } else {
×
826
                        entry.upd2Time = e.LastUpdate.Unix()
×
827
                }
×
828
                s.rejectCache.insert(e.ChannelID, entry)
×
829
        }
830

831
        // If an entry for this channel is found in channel cache, we'll modify
832
        // the entry with the updated policy for the direction that was just
833
        // written. If the edge doesn't exist, we'll defer loading the info and
834
        // policies and lazily read from disk during the next query.
835
        if channel, ok := s.chanCache.get(e.ChannelID); ok {
×
836
                if isUpdate1 {
×
837
                        channel.Policy1 = e
×
UNCOV
838
                } else {
×
UNCOV
839
                        channel.Policy2 = e
×
UNCOV
840
                }
×
UNCOV
841
                s.chanCache.insert(e.ChannelID, channel)
×
842
        }
843
}
844

845
// ForEachSourceNodeChannel iterates through all channels of the source node,
846
// executing the passed callback on each. The call-back is provided with the
847
// channel's outpoint, whether we have a policy for the channel and the channel
848
// peer's node information.
849
//
850
// NOTE: part of the V1Store interface.
851
func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
852
        cb func(chanPoint wire.OutPoint, havePolicy bool,
UNCOV
853
                otherNode *models.Node) error, reset func()) error {
×
UNCOV
854

×
UNCOV
855
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
856
                nodeID, nodePub, err := s.getSourceNode(
×
UNCOV
857
                        ctx, db, lnwire.GossipVersion1,
×
UNCOV
858
                )
×
UNCOV
859
                if err != nil {
×
UNCOV
860
                        return fmt.Errorf("unable to fetch source node: %w",
×
UNCOV
861
                                err)
×
862
                }
×
863

864
                return forEachNodeChannel(
×
865
                        ctx, db, s.cfg, nodeID,
×
866
                        func(info *models.ChannelEdgeInfo,
×
867
                                outPolicy *models.ChannelEdgePolicy,
×
868
                                _ *models.ChannelEdgePolicy) error {
×
869

×
UNCOV
870
                                // Fetch the other node.
×
871
                                var (
×
872
                                        otherNodePub [33]byte
×
873
                                        node1        = info.NodeKey1Bytes
×
874
                                        node2        = info.NodeKey2Bytes
×
875
                                )
×
876
                                switch {
×
877
                                case bytes.Equal(node1[:], nodePub[:]):
×
878
                                        otherNodePub = node2
×
879
                                case bytes.Equal(node2[:], nodePub[:]):
×
880
                                        otherNodePub = node1
×
881
                                default:
×
882
                                        return fmt.Errorf("node not " +
×
883
                                                "participating in this channel")
×
884
                                }
885

886
                                _, otherNode, err := getNodeByPubKey(
×
887
                                        ctx, s.cfg.QueryCfg, db, otherNodePub,
×
888
                                )
×
889
                                if err != nil {
×
890
                                        return fmt.Errorf("unable to fetch "+
×
UNCOV
891
                                                "other node(%x): %w",
×
UNCOV
892
                                                otherNodePub, err)
×
893
                                }
×
894

895
                                return cb(
×
896
                                        info.ChannelPoint, outPolicy != nil,
×
897
                                        otherNode,
×
898
                                )
×
899
                        },
900
                )
901
        }, reset)
902
}
903

904
// ForEachNode iterates through all the stored vertices/nodes in the graph,
905
// executing the passed callback with each node encountered. If the callback
906
// returns an error, then the transaction is aborted and the iteration stops
907
// early.
908
//
909
// NOTE: part of the V1Store interface.
910
func (s *SQLStore) ForEachNode(ctx context.Context,
UNCOV
911
        cb func(node *models.Node) error, reset func()) error {
×
UNCOV
912

×
UNCOV
913
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
914
                return forEachNodePaginated(
×
UNCOV
915
                        ctx, s.cfg.QueryCfg, db,
×
UNCOV
916
                        lnwire.GossipVersion1, func(_ context.Context, _ int64,
×
UNCOV
917
                                node *models.Node) error {
×
918

×
919
                                return cb(node)
×
920
                        },
×
921
                )
922
        }, reset)
923
}
924

925
// ForEachNodeDirectedChannel iterates through all channels of a given node,
926
// executing the passed callback on the directed edge representing the channel
927
// and its incoming policy. If the callback returns an error, then the iteration
928
// is halted with the error propagated back up to the caller.
929
//
930
// Unknown policies are passed into the callback as nil values.
931
//
932
// NOTE: this is part of the graphdb.NodeTraverser interface.
933
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
UNCOV
934
        cb func(channel *DirectedChannel) error, reset func()) error {
×
UNCOV
935

×
UNCOV
936
        var ctx = context.TODO()
×
UNCOV
937

×
UNCOV
938
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
939
                return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
×
UNCOV
940
        }, reset)
×
941
}
942

943
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
944
// graph, executing the passed callback with each node encountered. If the
945
// callback returns an error, then the transaction is aborted and the iteration
946
// stops early.
947
func (s *SQLStore) ForEachNodeCacheable(ctx context.Context,
948
        cb func(route.Vertex, *lnwire.FeatureVector) error,
UNCOV
949
        reset func()) error {
×
UNCOV
950

×
UNCOV
951
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
952
                return forEachNodeCacheable(
×
UNCOV
953
                        ctx, s.cfg.QueryCfg, db,
×
UNCOV
954
                        func(_ int64, nodePub route.Vertex,
×
UNCOV
955
                                features *lnwire.FeatureVector) error {
×
956

×
957
                                return cb(nodePub, features)
×
958
                        },
×
959
                )
960
        }, reset)
961
        if err != nil {
×
962
                return fmt.Errorf("unable to fetch nodes: %w", err)
×
963
        }
×
964

965
        return nil
×
966
}
967

968
// ForEachNodeChannel iterates through all channels of the given node,
969
// executing the passed callback with an edge info structure and the policies
970
// of each end of the channel. The first edge policy is the outgoing edge *to*
971
// the connecting node, while the second is the incoming edge *from* the
972
// connecting node. If the callback returns an error, then the iteration is
973
// halted with the error propagated back up to the caller.
974
//
975
// Unknown policies are passed into the callback as nil values.
976
//
977
// NOTE: part of the V1Store interface.
978
func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
979
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
UNCOV
980
                *models.ChannelEdgePolicy) error, reset func()) error {
×
UNCOV
981

×
UNCOV
982
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
983
                dbNode, err := db.GetNodeByPubKey(
×
UNCOV
984
                        ctx, sqlc.GetNodeByPubKeyParams{
×
UNCOV
985
                                Version: int16(lnwire.GossipVersion1),
×
UNCOV
986
                                PubKey:  nodePub[:],
×
987
                        },
×
988
                )
×
989
                if errors.Is(err, sql.ErrNoRows) {
×
990
                        return nil
×
991
                } else if err != nil {
×
992
                        return fmt.Errorf("unable to fetch node: %w", err)
×
993
                }
×
994

995
                return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
×
996
        }, reset)
997
}
998

999
// extractMaxUpdateTime returns the maximum of the two policy update times.
1000
// This is used for pagination cursor tracking.
1001
func extractMaxUpdateTime(
1002
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 {
×
UNCOV
1003

×
UNCOV
1004
        switch {
×
UNCOV
1005
        case row.Policy1LastUpdate.Valid && row.Policy2LastUpdate.Valid:
×
UNCOV
1006
                return max(row.Policy1LastUpdate.Int64,
×
UNCOV
1007
                        row.Policy2LastUpdate.Int64)
×
UNCOV
1008
        case row.Policy1LastUpdate.Valid:
×
1009
                return row.Policy1LastUpdate.Int64
×
1010
        case row.Policy2LastUpdate.Valid:
×
1011
                return row.Policy2LastUpdate.Int64
×
1012
        default:
×
1013
                return 0
×
1014
        }
1015
}
1016

1017
// buildChannelFromRow constructs a ChannelEdge from a database row.
1018
// This includes building the nodes, channel info, and policies.
1019
func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries,
1020
        row sqlc.GetChannelsByPolicyLastUpdateRangeRow) (ChannelEdge, error) {
×
UNCOV
1021

×
UNCOV
1022
        node1, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode)
×
UNCOV
1023
        if err != nil {
×
UNCOV
1024
                return ChannelEdge{}, fmt.Errorf("unable to build node1: %w",
×
UNCOV
1025
                        err)
×
UNCOV
1026
        }
×
1027

1028
        node2, err := buildNode(ctx, s.cfg.QueryCfg, db, row.GraphNode_2)
×
1029
        if err != nil {
×
1030
                return ChannelEdge{}, fmt.Errorf("unable to build node2: %w",
×
1031
                        err)
×
1032
        }
×
1033

UNCOV
1034
        channel, err := getAndBuildEdgeInfo(
×
1035
                ctx, s.cfg, db,
×
1036
                row.GraphChannel, node1.PubKeyBytes,
×
1037
                node2.PubKeyBytes,
×
1038
        )
×
1039
        if err != nil {
×
UNCOV
1040
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
1041
                        "channel info: %w", err)
×
1042
        }
×
1043

1044
        dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1045
        if err != nil {
×
1046
                return ChannelEdge{}, fmt.Errorf("unable to extract "+
×
1047
                        "channel policies: %w", err)
×
1048
        }
×
1049

UNCOV
1050
        p1, p2, err := getAndBuildChanPolicies(
×
1051
                ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, channel.ChannelID,
×
1052
                node1.PubKeyBytes, node2.PubKeyBytes,
×
1053
        )
×
1054
        if err != nil {
×
1055
                return ChannelEdge{}, fmt.Errorf("unable to build "+
×
UNCOV
1056
                        "channel policies: %w", err)
×
1057
        }
×
1058

1059
        return ChannelEdge{
×
1060
                Info:    channel,
×
1061
                Policy1: p1,
×
1062
                Policy2: p2,
×
1063
                Node1:   node1,
×
1064
                Node2:   node2,
×
UNCOV
1065
        }, nil
×
1066
}
1067

1068
// updateChanCacheBatch updates the channel cache with multiple edges at once.
1069
// This method acquires the cache lock only once for the entire batch.
1070
func (s *SQLStore) updateChanCacheBatch(edgesToCache map[uint64]ChannelEdge) {
×
1071
        if len(edgesToCache) == 0 {
×
1072
                return
×
UNCOV
1073
        }
×
1074

UNCOV
1075
        s.cacheMu.Lock()
×
UNCOV
1076
        defer s.cacheMu.Unlock()
×
1077

×
1078
        for chanID, edge := range edgesToCache {
×
1079
                s.chanCache.insert(chanID, edge)
×
1080
        }
×
1081
}
1082

1083
// ChanUpdatesInHorizon returns all the known channel edges which have at least
1084
// one edge that has an update timestamp within the specified horizon.
1085
//
1086
// Iterator Lifecycle:
1087
// 1. Initialize state (edgesSeen map, cache tracking, pagination cursors)
1088
// 2. Query batch of channels with policies in time range
1089
// 3. For each channel: check if seen, check cache, or build from DB
1090
// 4. Yield channels to caller
1091
// 5. Update cache after successful batch
1092
// 6. Repeat with updated pagination cursor until no more results
1093
//
1094
// NOTE: This is part of the V1Store interface.
1095
func (s *SQLStore) ChanUpdatesInHorizon(startTime, endTime time.Time,
UNCOV
1096
        opts ...IteratorOption) iter.Seq2[ChannelEdge, error] {
×
UNCOV
1097

×
UNCOV
1098
        // Apply options.
×
UNCOV
1099
        cfg := defaultIteratorConfig()
×
UNCOV
1100
        for _, opt := range opts {
×
UNCOV
1101
                opt(cfg)
×
UNCOV
1102
        }
×
1103

1104
        return func(yield func(ChannelEdge, error) bool) {
×
1105
                var (
×
1106
                        ctx            = context.TODO()
×
1107
                        edgesSeen      = make(map[uint64]struct{})
×
1108
                        edgesToCache   = make(map[uint64]ChannelEdge)
×
1109
                        hits           int
×
UNCOV
1110
                        total          int
×
1111
                        lastUpdateTime sql.NullInt64
×
1112
                        lastID         sql.NullInt64
×
1113
                        hasMore        = true
×
1114
                )
×
1115

×
1116
                // Each iteration, we'll read a batch amount of channel updates
×
1117
                // (consulting the cache along the way), yield them, then loop
×
1118
                // back to decide if we have any more updates to read out.
×
1119
                for hasMore {
×
1120
                        var batch []ChannelEdge
×
1121

×
1122
                        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(),
×
1123
                                func(db SQLQueries) error {
×
1124
                                        //nolint:ll
×
1125
                                        params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{
×
1126
                                                Version: int16(lnwire.GossipVersion1),
×
1127
                                                StartTime: sqldb.SQLInt64(
×
1128
                                                        startTime.Unix(),
×
1129
                                                ),
×
1130
                                                EndTime: sqldb.SQLInt64(
×
1131
                                                        endTime.Unix(),
×
1132
                                                ),
×
1133
                                                LastUpdateTime: lastUpdateTime,
×
1134
                                                LastID:         lastID,
×
1135
                                                MaxResults: sql.NullInt32{
×
1136
                                                        Int32: int32(
×
1137
                                                                cfg.chanUpdateIterBatchSize,
×
1138
                                                        ),
×
1139
                                                        Valid: true,
×
1140
                                                },
×
1141
                                        }
×
1142
                                        //nolint:ll
×
1143
                                        rows, err := db.GetChannelsByPolicyLastUpdateRange(
×
1144
                                                ctx, params,
×
1145
                                        )
×
1146
                                        if err != nil {
×
1147
                                                return err
×
1148
                                        }
×
1149

1150
                                        //nolint:ll
1151
                                        hasMore = len(rows) == cfg.chanUpdateIterBatchSize
×
1152

×
1153
                                        //nolint:ll
×
1154
                                        for _, row := range rows {
×
1155
                                                lastUpdateTime = sql.NullInt64{
×
UNCOV
1156
                                                        Int64: extractMaxUpdateTime(row),
×
UNCOV
1157
                                                        Valid: true,
×
1158
                                                }
×
1159
                                                lastID = sql.NullInt64{
×
1160
                                                        Int64: row.GraphChannel.ID,
×
1161
                                                        Valid: true,
×
1162
                                                }
×
1163

×
1164
                                                // Skip if we've already
×
1165
                                                // processed this channel.
×
1166
                                                chanIDInt := byteOrder.Uint64(
×
1167
                                                        row.GraphChannel.Scid,
×
1168
                                                )
×
1169
                                                _, ok := edgesSeen[chanIDInt]
×
1170
                                                if ok {
×
1171
                                                        continue
×
1172
                                                }
1173

1174
                                                s.cacheMu.RLock()
×
1175
                                                channel, ok := s.chanCache.get(
×
1176
                                                        chanIDInt,
×
1177
                                                )
×
1178
                                                s.cacheMu.RUnlock()
×
UNCOV
1179
                                                if ok {
×
UNCOV
1180
                                                        hits++
×
1181
                                                        total++
×
1182
                                                        edgesSeen[chanIDInt] = struct{}{}
×
1183
                                                        batch = append(batch, channel)
×
1184

×
1185
                                                        continue
×
1186
                                                }
1187

1188
                                                chanEdge, err := s.buildChannelFromRow(
×
1189
                                                        ctx, db, row,
×
1190
                                                )
×
1191
                                                if err != nil {
×
1192
                                                        return err
×
UNCOV
1193
                                                }
×
1194

1195
                                                edgesSeen[chanIDInt] = struct{}{}
×
1196
                                                edgesToCache[chanIDInt] = chanEdge
×
1197

×
1198
                                                batch = append(batch, chanEdge)
×
1199

×
1200
                                                total++
×
1201
                                        }
1202

1203
                                        return nil
×
1204
                                }, func() {
×
1205
                                        batch = nil
×
1206
                                        edgesSeen = make(map[uint64]struct{})
×
1207
                                        edgesToCache = make(
×
UNCOV
1208
                                                map[uint64]ChannelEdge,
×
UNCOV
1209
                                        )
×
1210
                                })
×
1211

1212
                        if err != nil {
×
1213
                                log.Errorf("ChanUpdatesInHorizon "+
×
1214
                                        "batch error: %v", err)
×
1215

×
1216
                                yield(ChannelEdge{}, err)
×
1217

×
UNCOV
1218
                                return
×
1219
                        }
×
1220

1221
                        for _, edge := range batch {
×
1222
                                if !yield(edge, nil) {
×
1223
                                        return
×
1224
                                }
×
1225
                        }
1226

1227
                        // Update cache after successful batch yield, setting
1228
                        // the cache lock only once for the entire batch.
1229
                        s.updateChanCacheBatch(edgesToCache)
×
1230
                        edgesToCache = make(map[uint64]ChannelEdge)
×
1231

×
UNCOV
1232
                        // If the batch didn't yield anything, then we're done.
×
UNCOV
1233
                        if len(batch) == 0 {
×
UNCOV
1234
                                break
×
1235
                        }
1236
                }
1237

1238
                if total > 0 {
×
1239
                        log.Debugf("ChanUpdatesInHorizon hit percentage: "+
×
1240
                                "%.2f (%d/%d)",
×
1241
                                float64(hits)*100/float64(total), hits, total)
×
UNCOV
1242
                } else {
×
UNCOV
1243
                        log.Debugf("ChanUpdatesInHorizon returned no edges "+
×
UNCOV
1244
                                "in horizon (%s, %s)", startTime, endTime)
×
1245
                }
×
1246
        }
1247
}
1248

1249
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1250
// data to the call-back. If withAddrs is true, then the call-back will also be
1251
// provided with the addresses associated with the node. The address retrieval
1252
// result in an additional round-trip to the database, so it should only be used
1253
// if the addresses are actually needed.
1254
//
1255
// NOTE: part of the V1Store interface.
1256
func (s *SQLStore) ForEachNodeCached(ctx context.Context, withAddrs bool,
1257
        cb func(ctx context.Context, node route.Vertex, addrs []net.Addr,
UNCOV
1258
                chans map[uint64]*DirectedChannel) error, reset func()) error {
×
UNCOV
1259

×
UNCOV
1260
        type nodeCachedBatchData struct {
×
UNCOV
1261
                features      map[int64][]int
×
UNCOV
1262
                addrs         map[int64][]nodeAddress
×
UNCOV
1263
                chanBatchData *batchChannelData
×
UNCOV
1264
                chanMap       map[int64][]sqlc.ListChannelsForNodeIDsRow
×
1265
        }
×
1266

×
1267
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1268
                // pageQueryFunc is used to query the next page of nodes.
×
1269
                pageQueryFunc := func(ctx context.Context, lastID int64,
×
1270
                        limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
1271

×
1272
                        return db.ListNodeIDsAndPubKeys(
×
1273
                                ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
1274
                                        Version: int16(lnwire.GossipVersion1),
×
1275
                                        ID:      lastID,
×
1276
                                        Limit:   limit,
×
1277
                                },
×
1278
                        )
×
1279
                }
×
1280

1281
                // batchDataFunc is then used to batch load the data required
1282
                // for each page of nodes.
1283
                batchDataFunc := func(ctx context.Context,
×
1284
                        nodeIDs []int64) (*nodeCachedBatchData, error) {
×
1285

×
1286
                        // Batch load node features.
×
UNCOV
1287
                        nodeFeatures, err := batchLoadNodeFeaturesHelper(
×
UNCOV
1288
                                ctx, s.cfg.QueryCfg, db, nodeIDs,
×
UNCOV
1289
                        )
×
1290
                        if err != nil {
×
1291
                                return nil, fmt.Errorf("unable to batch load "+
×
1292
                                        "node features: %w", err)
×
1293
                        }
×
1294

1295
                        // Maybe fetch the node's addresses if requested.
1296
                        var nodeAddrs map[int64][]nodeAddress
×
1297
                        if withAddrs {
×
1298
                                nodeAddrs, err = batchLoadNodeAddressesHelper(
×
1299
                                        ctx, s.cfg.QueryCfg, db, nodeIDs,
×
1300
                                )
×
UNCOV
1301
                                if err != nil {
×
UNCOV
1302
                                        return nil, fmt.Errorf("unable to "+
×
1303
                                                "batch load node "+
×
1304
                                                "addresses: %w", err)
×
1305
                                }
×
1306
                        }
1307

1308
                        // Batch load ALL unique channels for ALL nodes in this
1309
                        // page.
1310
                        allChannels, err := db.ListChannelsForNodeIDs(
×
1311
                                ctx, sqlc.ListChannelsForNodeIDsParams{
×
1312
                                        Version:  int16(lnwire.GossipVersion1),
×
UNCOV
1313
                                        Node1Ids: nodeIDs,
×
UNCOV
1314
                                        Node2Ids: nodeIDs,
×
UNCOV
1315
                                },
×
UNCOV
1316
                        )
×
1317
                        if err != nil {
×
1318
                                return nil, fmt.Errorf("unable to batch "+
×
1319
                                        "fetch channels for nodes: %w", err)
×
1320
                        }
×
1321

1322
                        // Deduplicate channels and collect IDs.
1323
                        var (
×
1324
                                allChannelIDs []int64
×
1325
                                allPolicyIDs  []int64
×
1326
                        )
×
1327
                        uniqueChannels := make(
×
UNCOV
1328
                                map[int64]sqlc.ListChannelsForNodeIDsRow,
×
UNCOV
1329
                        )
×
1330

×
1331
                        for _, channel := range allChannels {
×
1332
                                channelID := channel.GraphChannel.ID
×
1333

×
1334
                                // Only process each unique channel once.
×
1335
                                _, exists := uniqueChannels[channelID]
×
1336
                                if exists {
×
1337
                                        continue
×
1338
                                }
1339

1340
                                uniqueChannels[channelID] = channel
×
1341
                                allChannelIDs = append(allChannelIDs, channelID)
×
1342

×
1343
                                if channel.Policy1ID.Valid {
×
1344
                                        allPolicyIDs = append(
×
UNCOV
1345
                                                allPolicyIDs,
×
UNCOV
1346
                                                channel.Policy1ID.Int64,
×
1347
                                        )
×
1348
                                }
×
1349
                                if channel.Policy2ID.Valid {
×
1350
                                        allPolicyIDs = append(
×
1351
                                                allPolicyIDs,
×
1352
                                                channel.Policy2ID.Int64,
×
1353
                                        )
×
1354
                                }
×
1355
                        }
1356

1357
                        // Batch load channel data for all unique channels.
1358
                        channelBatchData, err := batchLoadChannelData(
×
1359
                                ctx, s.cfg.QueryCfg, db, allChannelIDs,
×
1360
                                allPolicyIDs,
×
1361
                        )
×
UNCOV
1362
                        if err != nil {
×
UNCOV
1363
                                return nil, fmt.Errorf("unable to batch "+
×
UNCOV
1364
                                        "load channel data: %w", err)
×
1365
                        }
×
1366

1367
                        // Create map of node ID to channels that involve this
1368
                        // node.
1369
                        nodeIDSet := make(map[int64]bool)
×
1370
                        for _, nodeID := range nodeIDs {
×
1371
                                nodeIDSet[nodeID] = true
×
1372
                        }
×
1373

UNCOV
1374
                        nodeChannelMap := make(
×
UNCOV
1375
                                map[int64][]sqlc.ListChannelsForNodeIDsRow,
×
1376
                        )
×
1377
                        for _, channel := range uniqueChannels {
×
1378
                                // Add channel to both nodes if they're in our
×
1379
                                // current page.
×
UNCOV
1380
                                node1 := channel.GraphChannel.NodeID1
×
1381
                                if nodeIDSet[node1] {
×
1382
                                        nodeChannelMap[node1] = append(
×
1383
                                                nodeChannelMap[node1], channel,
×
1384
                                        )
×
1385
                                }
×
1386
                                node2 := channel.GraphChannel.NodeID2
×
1387
                                if nodeIDSet[node2] {
×
1388
                                        nodeChannelMap[node2] = append(
×
1389
                                                nodeChannelMap[node2], channel,
×
1390
                                        )
×
1391
                                }
×
1392
                        }
1393

1394
                        return &nodeCachedBatchData{
×
1395
                                features:      nodeFeatures,
×
1396
                                addrs:         nodeAddrs,
×
1397
                                chanBatchData: channelBatchData,
×
1398
                                chanMap:       nodeChannelMap,
×
UNCOV
1399
                        }, nil
×
1400
                }
1401

1402
                // processItem is used to process each node in the current page.
1403
                processItem := func(ctx context.Context,
×
1404
                        nodeData sqlc.ListNodeIDsAndPubKeysRow,
×
1405
                        batchData *nodeCachedBatchData) error {
×
1406

×
UNCOV
1407
                        // Build feature vector for this node.
×
UNCOV
1408
                        fv := lnwire.EmptyFeatureVector()
×
UNCOV
1409
                        features, exists := batchData.features[nodeData.ID]
×
1410
                        if exists {
×
1411
                                for _, bit := range features {
×
1412
                                        fv.Set(lnwire.FeatureBit(bit))
×
1413
                                }
×
1414
                        }
1415

1416
                        var nodePub route.Vertex
×
1417
                        copy(nodePub[:], nodeData.PubKey)
×
1418

×
1419
                        nodeChannels := batchData.chanMap[nodeData.ID]
×
1420

×
UNCOV
1421
                        toNodeCallback := func() route.Vertex {
×
UNCOV
1422
                                return nodePub
×
1423
                        }
×
1424

1425
                        // Build cached channels map for this node.
1426
                        channels := make(map[uint64]*DirectedChannel)
×
1427
                        for _, channelRow := range nodeChannels {
×
1428
                                directedChan, err := buildDirectedChannel(
×
1429
                                        s.cfg.ChainHash, nodeData.ID, nodePub,
×
1430
                                        channelRow, batchData.chanBatchData, fv,
×
UNCOV
1431
                                        toNodeCallback,
×
UNCOV
1432
                                )
×
1433
                                if err != nil {
×
1434
                                        return err
×
1435
                                }
×
1436

1437
                                channels[directedChan.ChannelID] = directedChan
×
1438
                        }
1439

1440
                        addrs, err := buildNodeAddresses(
×
1441
                                batchData.addrs[nodeData.ID],
×
1442
                        )
×
UNCOV
1443
                        if err != nil {
×
1444
                                return fmt.Errorf("unable to build node "+
×
UNCOV
1445
                                        "addresses: %w", err)
×
UNCOV
1446
                        }
×
1447

1448
                        return cb(ctx, nodePub, addrs, channels)
×
1449
                }
1450

1451
                return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
1452
                        ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
×
1453
                        func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
UNCOV
1454
                                return node.ID
×
1455
                        },
×
1456
                        func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
UNCOV
1457
                                error) {
×
1458

×
1459
                                return node.ID, nil
×
1460
                        },
×
1461
                        batchDataFunc, processItem,
1462
                )
1463
        }, reset)
1464
}
1465

1466
// ForEachChannelCacheable iterates through all the channel edges stored
1467
// within the graph and invokes the passed callback for each edge. The
1468
// callback takes two edges as since this is a directed graph, both the
1469
// in/out edges are visited. If the callback returns an error, then the
1470
// transaction is aborted and the iteration stops early.
1471
//
1472
// NOTE: If an edge can't be found, or wasn't advertised, then a nil
1473
// pointer for that particular channel edge routing policy will be
1474
// passed into the callback.
1475
//
1476
// NOTE: this method is like ForEachChannel but fetches only the data
1477
// required for the graph cache.
1478
func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1479
        *models.CachedEdgePolicy, *models.CachedEdgePolicy) error,
UNCOV
1480
        reset func()) error {
×
UNCOV
1481

×
UNCOV
1482
        ctx := context.TODO()
×
UNCOV
1483

×
UNCOV
1484
        handleChannel := func(_ context.Context,
×
UNCOV
1485
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
×
UNCOV
1486

×
1487
                node1, node2, err := buildNodeVertices(
×
1488
                        row.Node1Pubkey, row.Node2Pubkey,
×
1489
                )
×
1490
                if err != nil {
×
1491
                        return err
×
1492
                }
×
1493

1494
                edge := buildCacheableChannelInfo(
×
1495
                        row.Scid, row.Capacity.Int64, node1, node2,
×
1496
                )
×
1497

×
1498
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
1499
                if err != nil {
×
UNCOV
1500
                        return err
×
1501
                }
×
1502

1503
                pol1, pol2, err := buildCachedChanPolicies(
×
1504
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
1505
                )
×
1506
                if err != nil {
×
1507
                        return err
×
1508
                }
×
1509

1510
                return cb(edge, pol1, pol2)
×
1511
        }
1512

1513
        extractCursor := func(
×
1514
                row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) int64 {
×
1515

×
UNCOV
1516
                return row.ID
×
1517
        }
×
1518

UNCOV
1519
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1520
                //nolint:ll
×
1521
                queryFunc := func(ctx context.Context, lastID int64,
×
1522
                        limit int32) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow,
×
1523
                        error) {
×
1524

×
UNCOV
1525
                        return db.ListChannelsWithPoliciesForCachePaginated(
×
1526
                                ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
×
1527
                                        Version: int16(lnwire.GossipVersion1),
×
1528
                                        ID:      lastID,
×
1529
                                        Limit:   limit,
×
1530
                                },
×
1531
                        )
×
1532
                }
×
1533

1534
                return sqldb.ExecutePaginatedQuery(
×
1535
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
1536
                        extractCursor, handleChannel,
×
1537
                )
×
1538
        }, reset)
1539
}
1540

1541
// ForEachChannel iterates through all the channel edges stored within the
1542
// graph and invokes the passed callback for each edge. The callback takes two
1543
// edges as since this is a directed graph, both the in/out edges are visited.
1544
// If the callback returns an error, then the transaction is aborted and the
1545
// iteration stops early.
1546
//
1547
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1548
// for that particular channel edge routing policy will be passed into the
1549
// callback.
1550
//
1551
// NOTE: part of the V1Store interface.
1552
func (s *SQLStore) ForEachChannel(ctx context.Context,
1553
        cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
UNCOV
1554
                *models.ChannelEdgePolicy) error, reset func()) error {
×
UNCOV
1555

×
UNCOV
1556
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
1557
                return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
×
UNCOV
1558
        }, reset)
×
1559
}
1560

1561
// FilterChannelRange returns the channel ID's of all known channels which were
1562
// mined in a block height within the passed range. The channel IDs are grouped
1563
// by their common block height. This method can be used to quickly share with a
1564
// peer the set of channels we know of within a particular range to catch them
1565
// up after a period of time offline. If withTimestamps is true then the
1566
// timestamp info of the latest received channel update messages of the channel
1567
// will be included in the response.
1568
//
1569
// NOTE: This is part of the V1Store interface.
1570
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
UNCOV
1571
        withTimestamps bool) ([]BlockChannelRange, error) {
×
UNCOV
1572

×
UNCOV
1573
        var (
×
UNCOV
1574
                ctx       = context.TODO()
×
UNCOV
1575
                startSCID = &lnwire.ShortChannelID{
×
UNCOV
1576
                        BlockHeight: startHeight,
×
UNCOV
1577
                }
×
1578
                endSCID = lnwire.ShortChannelID{
×
1579
                        BlockHeight: endHeight,
×
1580
                        TxIndex:     math.MaxUint32 & 0x00ffffff,
×
1581
                        TxPosition:  math.MaxUint16,
×
1582
                }
×
1583
                chanIDStart = channelIDToBytes(startSCID.ToUint64())
×
1584
                chanIDEnd   = channelIDToBytes(endSCID.ToUint64())
×
1585
        )
×
1586

×
1587
        // 1) get all channels where channelID is between start and end chan ID.
×
1588
        // 2) skip if not public (ie, no channel_proof)
×
1589
        // 3) collect that channel.
×
1590
        // 4) if timestamps are wanted, fetch both policies for node 1 and node2
×
1591
        //    and add those timestamps to the collected channel.
×
1592
        channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
×
1593
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1594
                dbChans, err := db.GetPublicV1ChannelsBySCID(
×
1595
                        ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
×
1596
                                StartScid: chanIDStart,
×
1597
                                EndScid:   chanIDEnd,
×
1598
                        },
×
1599
                )
×
1600
                if err != nil {
×
1601
                        return fmt.Errorf("unable to fetch channel range: %w",
×
1602
                                err)
×
1603
                }
×
1604

1605
                for _, dbChan := range dbChans {
×
1606
                        cid := lnwire.NewShortChanIDFromInt(
×
1607
                                byteOrder.Uint64(dbChan.Scid),
×
1608
                        )
×
1609
                        chanInfo := NewChannelUpdateInfo(
×
1610
                                cid, time.Time{}, time.Time{},
×
UNCOV
1611
                        )
×
1612

×
1613
                        if !withTimestamps {
×
1614
                                channelsPerBlock[cid.BlockHeight] = append(
×
1615
                                        channelsPerBlock[cid.BlockHeight],
×
1616
                                        chanInfo,
×
1617
                                )
×
1618

×
1619
                                continue
×
1620
                        }
1621

1622
                        //nolint:ll
1623
                        node1Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1624
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1625
                                        Version:   int16(lnwire.GossipVersion1),
×
1626
                                        ChannelID: dbChan.ID,
×
UNCOV
1627
                                        NodeID:    dbChan.NodeID1,
×
UNCOV
1628
                                },
×
UNCOV
1629
                        )
×
1630
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1631
                                return fmt.Errorf("unable to fetch node1 "+
×
1632
                                        "policy: %w", err)
×
1633
                        } else if err == nil {
×
1634
                                chanInfo.Node1UpdateTimestamp = time.Unix(
×
1635
                                        node1Policy.LastUpdate.Int64, 0,
×
1636
                                )
×
1637
                        }
×
1638

1639
                        //nolint:ll
1640
                        node2Policy, err := db.GetChannelPolicyByChannelAndNode(
×
1641
                                ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
1642
                                        Version:   int16(lnwire.GossipVersion1),
×
1643
                                        ChannelID: dbChan.ID,
×
1644
                                        NodeID:    dbChan.NodeID2,
×
UNCOV
1645
                                },
×
UNCOV
1646
                        )
×
1647
                        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
1648
                                return fmt.Errorf("unable to fetch node2 "+
×
1649
                                        "policy: %w", err)
×
1650
                        } else if err == nil {
×
1651
                                chanInfo.Node2UpdateTimestamp = time.Unix(
×
1652
                                        node2Policy.LastUpdate.Int64, 0,
×
1653
                                )
×
1654
                        }
×
1655

1656
                        channelsPerBlock[cid.BlockHeight] = append(
×
1657
                                channelsPerBlock[cid.BlockHeight], chanInfo,
×
1658
                        )
×
1659
                }
1660

1661
                return nil
×
UNCOV
1662
        }, func() {
×
1663
                channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
×
1664
        })
×
1665
        if err != nil {
×
UNCOV
1666
                return nil, fmt.Errorf("unable to fetch channel range: %w", err)
×
UNCOV
1667
        }
×
1668

1669
        if len(channelsPerBlock) == 0 {
×
1670
                return nil, nil
×
1671
        }
×
1672

1673
        // Return the channel ranges in ascending block height order.
1674
        blocks := slices.Collect(maps.Keys(channelsPerBlock))
×
UNCOV
1675
        slices.Sort(blocks)
×
1676

×
1677
        return fn.Map(blocks, func(block uint32) BlockChannelRange {
×
1678
                return BlockChannelRange{
×
UNCOV
1679
                        Height:   block,
×
UNCOV
1680
                        Channels: channelsPerBlock[block],
×
1681
                }
×
1682
        }), nil
×
1683
}
1684

1685
// MarkEdgeZombie attempts to mark a channel identified by its channel ID as a
1686
// zombie. This method is used on an ad-hoc basis, when channels need to be
1687
// marked as zombies outside the normal pruning cycle.
1688
//
1689
// NOTE: part of the V1Store interface.
1690
func (s *SQLStore) MarkEdgeZombie(chanID uint64,
UNCOV
1691
        pubKey1, pubKey2 [33]byte) error {
×
UNCOV
1692

×
UNCOV
1693
        ctx := context.TODO()
×
UNCOV
1694

×
UNCOV
1695
        s.cacheMu.Lock()
×
UNCOV
1696
        defer s.cacheMu.Unlock()
×
UNCOV
1697

×
1698
        chanIDB := channelIDToBytes(chanID)
×
1699

×
1700
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1701
                return db.UpsertZombieChannel(
×
1702
                        ctx, sqlc.UpsertZombieChannelParams{
×
1703
                                Version:  int16(lnwire.GossipVersion1),
×
1704
                                Scid:     chanIDB,
×
1705
                                NodeKey1: pubKey1[:],
×
1706
                                NodeKey2: pubKey2[:],
×
1707
                        },
×
1708
                )
×
1709
        }, sqldb.NoOpReset)
×
1710
        if err != nil {
×
1711
                return fmt.Errorf("unable to upsert zombie channel "+
×
1712
                        "(channel_id=%d): %w", chanID, err)
×
1713
        }
×
1714

1715
        s.rejectCache.remove(chanID)
×
1716
        s.chanCache.remove(chanID)
×
1717

×
1718
        return nil
×
1719
}
1720

1721
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
1722
//
1723
// NOTE: part of the V1Store interface.
1724
func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
×
1725
        s.cacheMu.Lock()
×
UNCOV
1726
        defer s.cacheMu.Unlock()
×
UNCOV
1727

×
UNCOV
1728
        var (
×
UNCOV
1729
                ctx     = context.TODO()
×
UNCOV
1730
                chanIDB = channelIDToBytes(chanID)
×
1731
        )
×
1732

×
1733
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1734
                res, err := db.DeleteZombieChannel(
×
1735
                        ctx, sqlc.DeleteZombieChannelParams{
×
1736
                                Scid:    chanIDB,
×
1737
                                Version: int16(lnwire.GossipVersion1),
×
1738
                        },
×
1739
                )
×
1740
                if err != nil {
×
1741
                        return fmt.Errorf("unable to delete zombie channel: %w",
×
1742
                                err)
×
1743
                }
×
1744

1745
                rows, err := res.RowsAffected()
×
1746
                if err != nil {
×
1747
                        return err
×
1748
                }
×
1749

1750
                if rows == 0 {
×
UNCOV
1751
                        return ErrZombieEdgeNotFound
×
1752
                } else if rows > 1 {
×
1753
                        return fmt.Errorf("deleted %d zombie rows, "+
×
1754
                                "expected 1", rows)
×
1755
                }
×
1756

1757
                return nil
×
1758
        }, sqldb.NoOpReset)
1759
        if err != nil {
×
1760
                return fmt.Errorf("unable to mark edge live "+
×
1761
                        "(channel_id=%d): %w", chanID, err)
×
1762
        }
×
1763

1764
        s.rejectCache.remove(chanID)
×
UNCOV
1765
        s.chanCache.remove(chanID)
×
1766

×
1767
        return err
×
1768
}
1769

1770
// IsZombieEdge returns whether the edge is considered zombie. If it is a
1771
// zombie, then the two node public keys corresponding to this edge are also
1772
// returned.
1773
//
1774
// NOTE: part of the V1Store interface.
1775
func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
UNCOV
1776
        error) {
×
UNCOV
1777

×
UNCOV
1778
        var (
×
UNCOV
1779
                ctx              = context.TODO()
×
UNCOV
1780
                isZombie         bool
×
UNCOV
1781
                pubKey1, pubKey2 route.Vertex
×
UNCOV
1782
                chanIDB          = channelIDToBytes(chanID)
×
1783
        )
×
1784

×
1785
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1786
                zombie, err := db.GetZombieChannel(
×
1787
                        ctx, sqlc.GetZombieChannelParams{
×
1788
                                Scid:    chanIDB,
×
1789
                                Version: int16(lnwire.GossipVersion1),
×
1790
                        },
×
1791
                )
×
1792
                if errors.Is(err, sql.ErrNoRows) {
×
1793
                        return nil
×
1794
                }
×
1795
                if err != nil {
×
1796
                        return fmt.Errorf("unable to fetch zombie channel: %w",
×
1797
                                err)
×
1798
                }
×
1799

1800
                copy(pubKey1[:], zombie.NodeKey1)
×
1801
                copy(pubKey2[:], zombie.NodeKey2)
×
1802
                isZombie = true
×
1803

×
1804
                return nil
×
1805
        }, sqldb.NoOpReset)
UNCOV
1806
        if err != nil {
×
1807
                return false, route.Vertex{}, route.Vertex{},
×
1808
                        fmt.Errorf("%w: %w (chanID=%d)",
×
1809
                                ErrCantCheckIfZombieEdgeStr, err, chanID)
×
1810
        }
×
1811

UNCOV
1812
        return isZombie, pubKey1, pubKey2, nil
×
1813
}
1814

1815
// NumZombies returns the current number of zombie channels in the graph.
1816
//
1817
// NOTE: part of the V1Store interface.
UNCOV
1818
func (s *SQLStore) NumZombies() (uint64, error) {
×
1819
        var (
×
UNCOV
1820
                ctx        = context.TODO()
×
UNCOV
1821
                numZombies uint64
×
UNCOV
1822
        )
×
UNCOV
1823
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
UNCOV
1824
                count, err := db.CountZombieChannels(
×
1825
                        ctx, int16(lnwire.GossipVersion1),
×
1826
                )
×
1827
                if err != nil {
×
1828
                        return fmt.Errorf("unable to count zombie channels: %w",
×
1829
                                err)
×
1830
                }
×
1831

1832
                numZombies = uint64(count)
×
1833

×
1834
                return nil
×
1835
        }, sqldb.NoOpReset)
UNCOV
1836
        if err != nil {
×
1837
                return 0, fmt.Errorf("unable to count zombies: %w", err)
×
1838
        }
×
1839

UNCOV
1840
        return numZombies, nil
×
1841
}
1842

1843
// DeleteChannelEdges removes edges with the given channel IDs from the
1844
// database and marks them as zombies. This ensures that we're unable to re-add
1845
// it to our database once again. If an edge does not exist within the
1846
// database, then ErrEdgeNotFound will be returned. If strictZombiePruning is
1847
// true, then when we mark these edges as zombies, we'll set up the keys such
1848
// that we require the node that failed to send the fresh update to be the one
1849
// that resurrects the channel from its zombie state. The markZombie bool
1850
// denotes whether to mark the channel as a zombie.
1851
//
1852
// NOTE: part of the V1Store interface.
1853
func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
UNCOV
1854
        chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) {
×
UNCOV
1855

×
UNCOV
1856
        s.cacheMu.Lock()
×
UNCOV
1857
        defer s.cacheMu.Unlock()
×
UNCOV
1858

×
1859
        // Keep track of which channels we end up finding so that we can
×
1860
        // correctly return ErrEdgeNotFound if we do not find a channel.
×
1861
        chanLookup := make(map[uint64]struct{}, len(chanIDs))
×
1862
        for _, chanID := range chanIDs {
×
1863
                chanLookup[chanID] = struct{}{}
×
1864
        }
×
1865

1866
        var (
×
1867
                ctx   = context.TODO()
×
1868
                edges []*models.ChannelEdgeInfo
×
1869
        )
×
UNCOV
1870
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
1871
                // First, collect all channel rows.
×
1872
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
1873
                chanCallBack := func(ctx context.Context,
×
1874
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
1875

×
1876
                        // Deleting the entry from the map indicates that we
×
1877
                        // have found the channel.
×
1878
                        scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1879
                        delete(chanLookup, scid)
×
1880

×
1881
                        channelRows = append(channelRows, row)
×
1882

×
1883
                        return nil
×
1884
                }
×
1885

1886
                err := s.forEachChanWithPoliciesInSCIDList(
×
1887
                        ctx, db, chanCallBack, chanIDs,
×
1888
                )
×
1889
                if err != nil {
×
UNCOV
1890
                        return err
×
1891
                }
×
1892

1893
                if len(chanLookup) > 0 {
×
1894
                        return ErrEdgeNotFound
×
1895
                }
×
1896

UNCOV
1897
                if len(channelRows) == 0 {
×
1898
                        return nil
×
1899
                }
×
1900

1901
                // Batch build all channel edges.
1902
                var chanIDsToDelete []int64
×
1903
                edges, chanIDsToDelete, err = batchBuildChannelInfo(
×
1904
                        ctx, s.cfg, db, channelRows,
×
UNCOV
1905
                )
×
UNCOV
1906
                if err != nil {
×
1907
                        return err
×
1908
                }
×
1909

1910
                if markZombie {
×
1911
                        for i, row := range channelRows {
×
1912
                                scid := byteOrder.Uint64(row.GraphChannel.Scid)
×
1913

×
UNCOV
1914
                                err := handleZombieMarking(
×
1915
                                        ctx, db, row, edges[i],
×
1916
                                        strictZombiePruning, scid,
×
1917
                                )
×
1918
                                if err != nil {
×
1919
                                        return fmt.Errorf("unable to mark "+
×
1920
                                                "channel as zombie: %w", err)
×
1921
                                }
×
1922
                        }
1923
                }
1924

1925
                return s.deleteChannels(ctx, db, chanIDsToDelete)
×
1926
        }, func() {
×
UNCOV
1927
                edges = nil
×
UNCOV
1928

×
UNCOV
1929
                // Re-fill the lookup map.
×
1930
                for _, chanID := range chanIDs {
×
1931
                        chanLookup[chanID] = struct{}{}
×
1932
                }
×
1933
        })
1934
        if err != nil {
×
1935
                return nil, fmt.Errorf("unable to delete channel edges: %w",
×
1936
                        err)
×
1937
        }
×
1938

1939
        for _, chanID := range chanIDs {
×
1940
                s.rejectCache.remove(chanID)
×
1941
                s.chanCache.remove(chanID)
×
1942
        }
×
1943

1944
        return edges, nil
×
1945
}
1946

1947
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1948
// channel identified by the channel ID. If the channel can't be found, then
1949
// ErrEdgeNotFound is returned. A struct which houses the general information
1950
// for the channel itself is returned as well as two structs that contain the
1951
// routing policies for the channel in either direction.
1952
//
1953
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1954
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1955
// the ChannelEdgeInfo will only include the public keys of each node.
1956
//
1957
// NOTE: part of the V1Store interface.
1958
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1959
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
UNCOV
1960
        *models.ChannelEdgePolicy, error) {
×
UNCOV
1961

×
UNCOV
1962
        var (
×
UNCOV
1963
                ctx              = context.TODO()
×
UNCOV
1964
                edge             *models.ChannelEdgeInfo
×
1965
                policy1, policy2 *models.ChannelEdgePolicy
×
1966
                chanIDB          = channelIDToBytes(chanID)
×
1967
        )
×
1968
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
1969
                row, err := db.GetChannelBySCIDWithPolicies(
×
1970
                        ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
×
1971
                                Scid:    chanIDB,
×
1972
                                Version: int16(lnwire.GossipVersion1),
×
1973
                        },
×
1974
                )
×
1975
                if errors.Is(err, sql.ErrNoRows) {
×
1976
                        // First check if this edge is perhaps in the zombie
×
1977
                        // index.
×
1978
                        zombie, err := db.GetZombieChannel(
×
1979
                                ctx, sqlc.GetZombieChannelParams{
×
1980
                                        Scid:    chanIDB,
×
1981
                                        Version: int16(lnwire.GossipVersion1),
×
1982
                                },
×
1983
                        )
×
1984
                        if errors.Is(err, sql.ErrNoRows) {
×
1985
                                return ErrEdgeNotFound
×
1986
                        } else if err != nil {
×
1987
                                return fmt.Errorf("unable to check if "+
×
1988
                                        "channel is zombie: %w", err)
×
1989
                        }
×
1990

1991
                        // At this point, we know the channel is a zombie, so
1992
                        // we'll return an error indicating this, and we will
1993
                        // populate the edge info with the public keys of each
1994
                        // party as this is the only information we have about
1995
                        // it.
UNCOV
1996
                        edge = &models.ChannelEdgeInfo{}
×
UNCOV
1997
                        copy(edge.NodeKey1Bytes[:], zombie.NodeKey1)
×
UNCOV
1998
                        copy(edge.NodeKey2Bytes[:], zombie.NodeKey2)
×
UNCOV
1999

×
UNCOV
2000
                        return ErrZombieEdge
×
2001
                } else if err != nil {
×
2002
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2003
                }
×
2004

2005
                node1, node2, err := buildNodeVertices(
×
2006
                        row.GraphNode.PubKey, row.GraphNode_2.PubKey,
×
2007
                )
×
2008
                if err != nil {
×
UNCOV
2009
                        return err
×
2010
                }
×
2011

2012
                edge, err = getAndBuildEdgeInfo(
×
2013
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2014
                )
×
2015
                if err != nil {
×
UNCOV
2016
                        return fmt.Errorf("unable to build channel info: %w",
×
2017
                                err)
×
2018
                }
×
2019

2020
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2021
                if err != nil {
×
2022
                        return fmt.Errorf("unable to extract channel "+
×
2023
                                "policies: %w", err)
×
UNCOV
2024
                }
×
2025

2026
                policy1, policy2, err = getAndBuildChanPolicies(
×
2027
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2028
                        node1, node2,
×
2029
                )
×
UNCOV
2030
                if err != nil {
×
2031
                        return fmt.Errorf("unable to build channel "+
×
2032
                                "policies: %w", err)
×
2033
                }
×
2034

2035
                return nil
×
2036
        }, sqldb.NoOpReset)
2037
        if err != nil {
×
2038
                // If we are returning the ErrZombieEdge, then we also need to
×
UNCOV
2039
                // return the edge info as the method comment indicates that
×
2040
                // this will be populated when the edge is a zombie.
×
UNCOV
2041
                return edge, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
2042
                        err)
×
2043
        }
×
2044

2045
        return edge, policy1, policy2, nil
×
2046
}
2047

2048
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
2049
// the channel identified by the funding outpoint. If the channel can't be
2050
// found, then ErrEdgeNotFound is returned. A struct which houses the general
2051
// information for the channel itself is returned as well as two structs that
2052
// contain the routing policies for the channel in either direction.
2053
//
2054
// NOTE: part of the V1Store interface.
2055
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
2056
        *models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
UNCOV
2057
        *models.ChannelEdgePolicy, error) {
×
UNCOV
2058

×
UNCOV
2059
        var (
×
UNCOV
2060
                ctx              = context.TODO()
×
UNCOV
2061
                edge             *models.ChannelEdgeInfo
×
2062
                policy1, policy2 *models.ChannelEdgePolicy
×
2063
        )
×
2064
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2065
                row, err := db.GetChannelByOutpointWithPolicies(
×
2066
                        ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
×
2067
                                Outpoint: op.String(),
×
2068
                                Version:  int16(lnwire.GossipVersion1),
×
2069
                        },
×
2070
                )
×
2071
                if errors.Is(err, sql.ErrNoRows) {
×
2072
                        return ErrEdgeNotFound
×
2073
                } else if err != nil {
×
2074
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2075
                }
×
2076

2077
                node1, node2, err := buildNodeVertices(
×
2078
                        row.Node1Pubkey, row.Node2Pubkey,
×
2079
                )
×
2080
                if err != nil {
×
UNCOV
2081
                        return err
×
2082
                }
×
2083

2084
                edge, err = getAndBuildEdgeInfo(
×
2085
                        ctx, s.cfg, db, row.GraphChannel, node1, node2,
×
2086
                )
×
2087
                if err != nil {
×
UNCOV
2088
                        return fmt.Errorf("unable to build channel info: %w",
×
2089
                                err)
×
2090
                }
×
2091

2092
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
2093
                if err != nil {
×
2094
                        return fmt.Errorf("unable to extract channel "+
×
2095
                                "policies: %w", err)
×
UNCOV
2096
                }
×
2097

2098
                policy1, policy2, err = getAndBuildChanPolicies(
×
2099
                        ctx, s.cfg.QueryCfg, db, dbPol1, dbPol2, edge.ChannelID,
×
2100
                        node1, node2,
×
2101
                )
×
UNCOV
2102
                if err != nil {
×
2103
                        return fmt.Errorf("unable to build channel "+
×
2104
                                "policies: %w", err)
×
2105
                }
×
2106

2107
                return nil
×
2108
        }, sqldb.NoOpReset)
2109
        if err != nil {
×
2110
                return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
×
UNCOV
2111
                        err)
×
2112
        }
×
2113

2114
        return edge, policy1, policy2, nil
×
2115
}
2116

2117
// HasChannelEdge returns true if the database knows of a channel edge with the
2118
// passed channel ID, and false otherwise. If an edge with that ID is found
2119
// within the graph, then two time stamps representing the last time the edge
2120
// was updated for both directed edges are returned along with the boolean. If
2121
// it is not found, then the zombie index is checked and its result is returned
2122
// as the second boolean.
2123
//
2124
// NOTE: part of the V1Store interface.
2125
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
UNCOV
2126
        bool, error) {
×
UNCOV
2127

×
UNCOV
2128
        ctx := context.TODO()
×
UNCOV
2129

×
UNCOV
2130
        var (
×
2131
                exists          bool
×
2132
                isZombie        bool
×
2133
                node1LastUpdate time.Time
×
2134
                node2LastUpdate time.Time
×
2135
        )
×
2136

×
2137
        // We'll query the cache with the shared lock held to allow multiple
×
2138
        // readers to access values in the cache concurrently if they exist.
×
2139
        s.cacheMu.RLock()
×
2140
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2141
                s.cacheMu.RUnlock()
×
2142
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2143
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2144
                exists, isZombie = entry.flags.unpack()
×
2145

×
2146
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2147
        }
×
2148
        s.cacheMu.RUnlock()
×
2149

×
2150
        s.cacheMu.Lock()
×
2151
        defer s.cacheMu.Unlock()
×
2152

×
2153
        // The item was not found with the shared lock, so we'll acquire the
×
2154
        // exclusive lock and check the cache again in case another method added
×
2155
        // the entry to the cache while no lock was held.
×
2156
        if entry, ok := s.rejectCache.get(chanID); ok {
×
2157
                node1LastUpdate = time.Unix(entry.upd1Time, 0)
×
2158
                node2LastUpdate = time.Unix(entry.upd2Time, 0)
×
2159
                exists, isZombie = entry.flags.unpack()
×
2160

×
2161
                return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2162
        }
×
2163

2164
        chanIDB := channelIDToBytes(chanID)
×
2165
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2166
                channel, err := db.GetChannelBySCID(
×
2167
                        ctx, sqlc.GetChannelBySCIDParams{
×
UNCOV
2168
                                Scid:    chanIDB,
×
2169
                                Version: int16(lnwire.GossipVersion1),
×
2170
                        },
×
2171
                )
×
2172
                if errors.Is(err, sql.ErrNoRows) {
×
2173
                        // Check if it is a zombie channel.
×
2174
                        isZombie, err = db.IsZombieChannel(
×
2175
                                ctx, sqlc.IsZombieChannelParams{
×
2176
                                        Scid:    chanIDB,
×
2177
                                        Version: int16(lnwire.GossipVersion1),
×
2178
                                },
×
2179
                        )
×
2180
                        if err != nil {
×
2181
                                return fmt.Errorf("could not check if channel "+
×
2182
                                        "is zombie: %w", err)
×
2183
                        }
×
2184

2185
                        return nil
×
2186
                } else if err != nil {
×
2187
                        return fmt.Errorf("unable to fetch channel: %w", err)
×
2188
                }
×
2189

2190
                exists = true
×
2191

×
2192
                policy1, err := db.GetChannelPolicyByChannelAndNode(
×
2193
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
UNCOV
2194
                                Version:   int16(lnwire.GossipVersion1),
×
2195
                                ChannelID: channel.ID,
×
2196
                                NodeID:    channel.NodeID1,
×
2197
                        },
×
2198
                )
×
2199
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2200
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2201
                                err)
×
2202
                } else if err == nil {
×
2203
                        node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
×
2204
                }
×
2205

2206
                policy2, err := db.GetChannelPolicyByChannelAndNode(
×
2207
                        ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
×
2208
                                Version:   int16(lnwire.GossipVersion1),
×
2209
                                ChannelID: channel.ID,
×
UNCOV
2210
                                NodeID:    channel.NodeID2,
×
2211
                        },
×
2212
                )
×
2213
                if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
2214
                        return fmt.Errorf("unable to fetch channel policy: %w",
×
2215
                                err)
×
2216
                } else if err == nil {
×
2217
                        node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
×
2218
                }
×
2219

2220
                return nil
×
2221
        }, sqldb.NoOpReset)
2222
        if err != nil {
×
2223
                return time.Time{}, time.Time{}, false, false,
×
UNCOV
2224
                        fmt.Errorf("unable to fetch channel: %w", err)
×
2225
        }
×
2226

2227
        s.rejectCache.insert(chanID, rejectCacheEntry{
×
2228
                upd1Time: node1LastUpdate.Unix(),
×
2229
                upd2Time: node2LastUpdate.Unix(),
×
2230
                flags:    packRejectFlags(exists, isZombie),
×
UNCOV
2231
        })
×
2232

×
2233
        return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
×
2234
}
2235

2236
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
2237
// passed channel point (outpoint). If the passed channel doesn't exist within
2238
// the database, then ErrEdgeNotFound is returned.
2239
//
2240
// NOTE: part of the V1Store interface.
UNCOV
2241
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
×
UNCOV
2242
        var (
×
UNCOV
2243
                ctx       = context.TODO()
×
UNCOV
2244
                channelID uint64
×
UNCOV
2245
        )
×
2246
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2247
                chanID, err := db.GetSCIDByOutpoint(
×
2248
                        ctx, sqlc.GetSCIDByOutpointParams{
×
2249
                                Outpoint: chanPoint.String(),
×
2250
                                Version:  int16(lnwire.GossipVersion1),
×
2251
                        },
×
2252
                )
×
2253
                if errors.Is(err, sql.ErrNoRows) {
×
2254
                        return ErrEdgeNotFound
×
2255
                } else if err != nil {
×
2256
                        return fmt.Errorf("unable to fetch channel ID: %w",
×
2257
                                err)
×
2258
                }
×
2259

2260
                channelID = byteOrder.Uint64(chanID)
×
2261

×
2262
                return nil
×
2263
        }, sqldb.NoOpReset)
UNCOV
2264
        if err != nil {
×
2265
                return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
×
2266
        }
×
2267

UNCOV
2268
        return channelID, nil
×
2269
}
2270

2271
// IsPublicNode is a helper method that determines whether the node with the
2272
// given public key is seen as a public node in the graph from the graph's
2273
// source node's point of view.
2274
//
2275
// NOTE: part of the V1Store interface.
UNCOV
2276
func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
×
UNCOV
2277
        ctx := context.TODO()
×
UNCOV
2278

×
UNCOV
2279
        var isPublic bool
×
UNCOV
2280
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2281
                var err error
×
2282
                isPublic, err = db.IsPublicV1Node(ctx, pubKey[:])
×
2283

×
2284
                return err
×
2285
        }, sqldb.NoOpReset)
×
2286
        if err != nil {
×
2287
                return false, fmt.Errorf("unable to check if node is "+
×
2288
                        "public: %w", err)
×
2289
        }
×
2290

2291
        return isPublic, nil
×
2292
}
2293

2294
// FetchChanInfos returns the set of channel edges that correspond to the passed
2295
// channel ID's. If an edge is the query is unknown to the database, it will
2296
// skipped and the result will contain only those edges that exist at the time
2297
// of the query. This can be used to respond to peer queries that are seeking to
2298
// fill in gaps in their view of the channel graph.
2299
//
2300
// NOTE: part of the V1Store interface.
UNCOV
2301
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
×
UNCOV
2302
        var (
×
UNCOV
2303
                ctx   = context.TODO()
×
UNCOV
2304
                edges = make(map[uint64]ChannelEdge)
×
UNCOV
2305
        )
×
2306
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2307
                // First, collect all channel rows.
×
2308
                var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow
×
2309
                chanCallBack := func(ctx context.Context,
×
2310
                        row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
×
2311

×
2312
                        channelRows = append(channelRows, row)
×
2313
                        return nil
×
2314
                }
×
2315

2316
                err := s.forEachChanWithPoliciesInSCIDList(
×
2317
                        ctx, db, chanCallBack, chanIDs,
×
2318
                )
×
2319
                if err != nil {
×
UNCOV
2320
                        return err
×
2321
                }
×
2322

2323
                if len(channelRows) == 0 {
×
2324
                        return nil
×
2325
                }
×
2326

2327
                // Batch build all channel edges.
2328
                chans, err := batchBuildChannelEdges(
×
2329
                        ctx, s.cfg, db, channelRows,
×
2330
                )
×
UNCOV
2331
                if err != nil {
×
UNCOV
2332
                        return fmt.Errorf("unable to build channel edges: %w",
×
2333
                                err)
×
2334
                }
×
2335

2336
                for _, c := range chans {
×
2337
                        edges[c.Info.ChannelID] = c
×
2338
                }
×
2339

UNCOV
2340
                return err
×
2341
        }, func() {
×
2342
                clear(edges)
×
2343
        })
×
UNCOV
2344
        if err != nil {
×
2345
                return nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2346
        }
×
2347

2348
        res := make([]ChannelEdge, 0, len(edges))
×
2349
        for _, chanID := range chanIDs {
×
2350
                edge, ok := edges[chanID]
×
2351
                if !ok {
×
UNCOV
2352
                        continue
×
2353
                }
2354

2355
                res = append(res, edge)
×
2356
        }
2357

UNCOV
2358
        return res, nil
×
2359
}
2360

2361
// forEachChanWithPoliciesInSCIDList is a wrapper around the
2362
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
2363
// channels in a paginated manner.
2364
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
2365
        db SQLQueries, cb func(ctx context.Context,
2366
                row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
UNCOV
2367
        chanIDs []uint64) error {
×
UNCOV
2368

×
UNCOV
2369
        queryWrapper := func(ctx context.Context,
×
UNCOV
2370
                scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
×
UNCOV
2371
                error) {
×
2372

×
2373
                return db.GetChannelsBySCIDWithPolicies(
×
2374
                        ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
×
2375
                                Version: int16(lnwire.GossipVersion1),
×
2376
                                Scids:   scids,
×
2377
                        },
×
2378
                )
×
2379
        }
×
2380

2381
        return sqldb.ExecuteBatchQuery(
×
2382
                ctx, s.cfg.QueryCfg, chanIDs, channelIDToBytes, queryWrapper,
×
2383
                cb,
×
2384
        )
×
2385
}
2386

2387
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
2388
// ID's that we don't know and are not known zombies of the passed set. In other
2389
// words, we perform a set difference of our set of chan ID's and the ones
2390
// passed in. This method can be used by callers to determine the set of
2391
// channels another peer knows of that we don't. The ChannelUpdateInfos for the
2392
// known zombies is also returned.
2393
//
2394
// NOTE: part of the V1Store interface.
2395
func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
UNCOV
2396
        []ChannelUpdateInfo, error) {
×
UNCOV
2397

×
UNCOV
2398
        var (
×
UNCOV
2399
                ctx          = context.TODO()
×
UNCOV
2400
                newChanIDs   []uint64
×
2401
                knownZombies []ChannelUpdateInfo
×
2402
                infoLookup   = make(
×
2403
                        map[uint64]ChannelUpdateInfo, len(chansInfo),
×
2404
                )
×
2405
        )
×
2406

×
2407
        // We first build a lookup map of the channel ID's to the
×
2408
        // ChannelUpdateInfo. This allows us to quickly delete channels that we
×
2409
        // already know about.
×
2410
        for _, chanInfo := range chansInfo {
×
2411
                infoLookup[chanInfo.ShortChannelID.ToUint64()] = chanInfo
×
2412
        }
×
2413

2414
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2415
                // The call-back function deletes known channels from
×
2416
                // infoLookup, so that we can later check which channels are
×
2417
                // zombies by only looking at the remaining channels in the set.
×
UNCOV
2418
                cb := func(ctx context.Context,
×
2419
                        channel sqlc.GraphChannel) error {
×
2420

×
2421
                        delete(infoLookup, byteOrder.Uint64(channel.Scid))
×
2422

×
2423
                        return nil
×
2424
                }
×
2425

2426
                err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo)
×
2427
                if err != nil {
×
2428
                        return fmt.Errorf("unable to iterate through "+
×
2429
                                "channels: %w", err)
×
UNCOV
2430
                }
×
2431

2432
                // We want to ensure that we deal with the channels in the
2433
                // same order that they were passed in, so we iterate over the
2434
                // original chansInfo slice and then check if that channel is
2435
                // still in the infoLookup map.
UNCOV
2436
                for _, chanInfo := range chansInfo {
×
UNCOV
2437
                        channelID := chanInfo.ShortChannelID.ToUint64()
×
UNCOV
2438
                        if _, ok := infoLookup[channelID]; !ok {
×
UNCOV
2439
                                continue
×
2440
                        }
2441

2442
                        isZombie, err := db.IsZombieChannel(
×
2443
                                ctx, sqlc.IsZombieChannelParams{
×
2444
                                        Scid:    channelIDToBytes(channelID),
×
UNCOV
2445
                                        Version: int16(lnwire.GossipVersion1),
×
UNCOV
2446
                                },
×
2447
                        )
×
2448
                        if err != nil {
×
2449
                                return fmt.Errorf("unable to fetch zombie "+
×
2450
                                        "channel: %w", err)
×
2451
                        }
×
2452

2453
                        if isZombie {
×
2454
                                knownZombies = append(knownZombies, chanInfo)
×
2455

×
2456
                                continue
×
2457
                        }
2458

2459
                        newChanIDs = append(newChanIDs, channelID)
×
2460
                }
2461

UNCOV
2462
                return nil
×
UNCOV
2463
        }, func() {
×
2464
                newChanIDs = nil
×
UNCOV
2465
                knownZombies = nil
×
UNCOV
2466
                // Rebuild the infoLookup map in case of a rollback.
×
2467
                for _, chanInfo := range chansInfo {
×
2468
                        scid := chanInfo.ShortChannelID.ToUint64()
×
2469
                        infoLookup[scid] = chanInfo
×
2470
                }
×
2471
        })
2472
        if err != nil {
×
2473
                return nil, nil, fmt.Errorf("unable to fetch channels: %w", err)
×
2474
        }
×
2475

UNCOV
2476
        return newChanIDs, knownZombies, nil
×
2477
}
2478

2479
// forEachChanInSCIDList is a helper method that executes a paged query
2480
// against the database to fetch all channels that match the passed
2481
// ChannelUpdateInfo slice. The callback function is called for each channel
2482
// that is found.
2483
func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries,
2484
        cb func(ctx context.Context, channel sqlc.GraphChannel) error,
UNCOV
2485
        chansInfo []ChannelUpdateInfo) error {
×
UNCOV
2486

×
UNCOV
2487
        queryWrapper := func(ctx context.Context,
×
UNCOV
2488
                scids [][]byte) ([]sqlc.GraphChannel, error) {
×
UNCOV
2489

×
2490
                return db.GetChannelsBySCIDs(
×
2491
                        ctx, sqlc.GetChannelsBySCIDsParams{
×
2492
                                Version: int16(lnwire.GossipVersion1),
×
2493
                                Scids:   scids,
×
2494
                        },
×
2495
                )
×
2496
        }
×
2497

2498
        chanIDConverter := func(chanInfo ChannelUpdateInfo) []byte {
×
2499
                channelID := chanInfo.ShortChannelID.ToUint64()
×
2500

×
2501
                return channelIDToBytes(channelID)
×
UNCOV
2502
        }
×
2503

2504
        return sqldb.ExecuteBatchQuery(
×
2505
                ctx, s.cfg.QueryCfg, chansInfo, chanIDConverter, queryWrapper,
×
2506
                cb,
×
2507
        )
×
2508
}
2509

2510
// PruneGraphNodes is a garbage collection method which attempts to prune out
2511
// any nodes from the channel graph that are currently unconnected. This ensure
2512
// that we only maintain a graph of reachable nodes. In the event that a pruned
2513
// node gains more channels, it will be re-added back to the graph.
2514
//
2515
// NOTE: this prunes nodes across protocol versions. It will never prune the
2516
// source nodes.
2517
//
2518
// NOTE: part of the V1Store interface.
UNCOV
2519
func (s *SQLStore) PruneGraphNodes() ([]route.Vertex, error) {
×
UNCOV
2520
        var ctx = context.TODO()
×
UNCOV
2521

×
UNCOV
2522
        var prunedNodes []route.Vertex
×
UNCOV
2523
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2524
                var err error
×
2525
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2526

×
2527
                return err
×
2528
        }, func() {
×
2529
                prunedNodes = nil
×
2530
        })
×
2531
        if err != nil {
×
2532
                return nil, fmt.Errorf("unable to prune nodes: %w", err)
×
2533
        }
×
2534

2535
        return prunedNodes, nil
×
2536
}
2537

2538
// PruneGraph prunes newly closed channels from the channel graph in response
2539
// to a new block being solved on the network. Any transactions which spend the
2540
// funding output of any known channels within he graph will be deleted.
2541
// Additionally, the "prune tip", or the last block which has been used to
2542
// prune the graph is stored so callers can ensure the graph is fully in sync
2543
// with the current UTXO state. A slice of channels that have been closed by
2544
// the target block along with any pruned nodes are returned if the function
2545
// succeeds without error.
2546
//
2547
// NOTE: part of the V1Store interface.
2548
func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint,
2549
        blockHash *chainhash.Hash, blockHeight uint32) (
UNCOV
2550
        []*models.ChannelEdgeInfo, []route.Vertex, error) {
×
UNCOV
2551

×
UNCOV
2552
        ctx := context.TODO()
×
UNCOV
2553

×
UNCOV
2554
        s.cacheMu.Lock()
×
2555
        defer s.cacheMu.Unlock()
×
2556

×
2557
        var (
×
2558
                closedChans []*models.ChannelEdgeInfo
×
2559
                prunedNodes []route.Vertex
×
2560
        )
×
2561
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2562
                // First, collect all channel rows that need to be pruned.
×
2563
                var channelRows []sqlc.GetChannelsByOutpointsRow
×
2564
                channelCallback := func(ctx context.Context,
×
2565
                        row sqlc.GetChannelsByOutpointsRow) error {
×
2566

×
2567
                        channelRows = append(channelRows, row)
×
2568

×
2569
                        return nil
×
2570
                }
×
2571

2572
                err := s.forEachChanInOutpoints(
×
2573
                        ctx, db, spentOutputs, channelCallback,
×
2574
                )
×
2575
                if err != nil {
×
UNCOV
2576
                        return fmt.Errorf("unable to fetch channels by "+
×
2577
                                "outpoints: %w", err)
×
2578
                }
×
2579

2580
                if len(channelRows) == 0 {
×
2581
                        // There are no channels to prune. So we can exit early
×
2582
                        // after updating the prune log.
×
2583
                        err = db.UpsertPruneLogEntry(
×
UNCOV
2584
                                ctx, sqlc.UpsertPruneLogEntryParams{
×
2585
                                        BlockHash:   blockHash[:],
×
2586
                                        BlockHeight: int64(blockHeight),
×
2587
                                },
×
2588
                        )
×
2589
                        if err != nil {
×
2590
                                return fmt.Errorf("unable to insert prune log "+
×
2591
                                        "entry: %w", err)
×
2592
                        }
×
2593

2594
                        return nil
×
2595
                }
2596

2597
                // Batch build all channel edges for pruning.
UNCOV
2598
                var chansToDelete []int64
×
2599
                closedChans, chansToDelete, err = batchBuildChannelInfo(
×
UNCOV
2600
                        ctx, s.cfg, db, channelRows,
×
UNCOV
2601
                )
×
UNCOV
2602
                if err != nil {
×
2603
                        return err
×
2604
                }
×
2605

2606
                err = s.deleteChannels(ctx, db, chansToDelete)
×
2607
                if err != nil {
×
2608
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2609
                }
×
2610

2611
                err = db.UpsertPruneLogEntry(
×
2612
                        ctx, sqlc.UpsertPruneLogEntryParams{
×
2613
                                BlockHash:   blockHash[:],
×
2614
                                BlockHeight: int64(blockHeight),
×
UNCOV
2615
                        },
×
2616
                )
×
2617
                if err != nil {
×
2618
                        return fmt.Errorf("unable to insert prune log "+
×
2619
                                "entry: %w", err)
×
2620
                }
×
2621

2622
                // Now that we've pruned some channels, we'll also prune any
2623
                // nodes that no longer have any channels.
2624
                prunedNodes, err = s.pruneGraphNodes(ctx, db)
×
2625
                if err != nil {
×
UNCOV
2626
                        return fmt.Errorf("unable to prune graph nodes: %w",
×
UNCOV
2627
                                err)
×
UNCOV
2628
                }
×
2629

2630
                return nil
×
2631
        }, func() {
×
2632
                prunedNodes = nil
×
2633
                closedChans = nil
×
UNCOV
2634
        })
×
2635
        if err != nil {
×
2636
                return nil, nil, fmt.Errorf("unable to prune graph: %w", err)
×
2637
        }
×
2638

2639
        for _, channel := range closedChans {
×
2640
                s.rejectCache.remove(channel.ChannelID)
×
2641
                s.chanCache.remove(channel.ChannelID)
×
2642
        }
×
2643

2644
        return closedChans, prunedNodes, nil
×
2645
}
2646

2647
// forEachChanInOutpoints is a helper function that executes a paginated
2648
// query to fetch channels by their outpoints and applies the given call-back
2649
// to each.
2650
//
2651
// NOTE: this fetches channels for all protocol versions.
2652
func (s *SQLStore) forEachChanInOutpoints(ctx context.Context, db SQLQueries,
2653
        outpoints []*wire.OutPoint, cb func(ctx context.Context,
UNCOV
2654
                row sqlc.GetChannelsByOutpointsRow) error) error {
×
UNCOV
2655

×
UNCOV
2656
        // Create a wrapper that uses the transaction's db instance to execute
×
UNCOV
2657
        // the query.
×
UNCOV
2658
        queryWrapper := func(ctx context.Context,
×
2659
                pageOutpoints []string) ([]sqlc.GetChannelsByOutpointsRow,
×
2660
                error) {
×
2661

×
2662
                return db.GetChannelsByOutpoints(ctx, pageOutpoints)
×
2663
        }
×
2664

2665
        // Define the conversion function from Outpoint to string.
2666
        outpointToString := func(outpoint *wire.OutPoint) string {
×
2667
                return outpoint.String()
×
2668
        }
×
2669

UNCOV
2670
        return sqldb.ExecuteBatchQuery(
×
2671
                ctx, s.cfg.QueryCfg, outpoints, outpointToString,
×
2672
                queryWrapper, cb,
×
2673
        )
×
2674
}
2675

2676
func (s *SQLStore) deleteChannels(ctx context.Context, db SQLQueries,
2677
        dbIDs []int64) error {
×
2678

×
UNCOV
2679
        // Create a wrapper that uses the transaction's db instance to execute
×
UNCOV
2680
        // the query.
×
UNCOV
2681
        queryWrapper := func(ctx context.Context, ids []int64) ([]any, error) {
×
2682
                return nil, db.DeleteChannels(ctx, ids)
×
2683
        }
×
2684

2685
        idConverter := func(id int64) int64 {
×
2686
                return id
×
2687
        }
×
2688

UNCOV
2689
        return sqldb.ExecuteBatchQuery(
×
2690
                ctx, s.cfg.QueryCfg, dbIDs, idConverter,
×
2691
                queryWrapper, func(ctx context.Context, _ any) error {
×
2692
                        return nil
×
UNCOV
2693
                },
×
2694
        )
2695
}
2696

2697
// ChannelView returns the verifiable edge information for each active channel
2698
// within the known channel graph. The set of UTXOs (along with their scripts)
2699
// returned are the ones that need to be watched on chain to detect channel
2700
// closes on the resident blockchain.
2701
//
2702
// NOTE: part of the V1Store interface.
UNCOV
2703
func (s *SQLStore) ChannelView() ([]EdgePoint, error) {
×
UNCOV
2704
        var (
×
UNCOV
2705
                ctx        = context.TODO()
×
UNCOV
2706
                edgePoints []EdgePoint
×
UNCOV
2707
        )
×
2708

×
2709
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2710
                handleChannel := func(_ context.Context,
×
2711
                        channel sqlc.ListChannelsPaginatedRow) error {
×
2712

×
2713
                        pkScript, err := genMultiSigP2WSH(
×
2714
                                channel.BitcoinKey1, channel.BitcoinKey2,
×
2715
                        )
×
2716
                        if err != nil {
×
2717
                                return err
×
2718
                        }
×
2719

2720
                        op, err := wire.NewOutPointFromString(channel.Outpoint)
×
2721
                        if err != nil {
×
2722
                                return err
×
2723
                        }
×
2724

2725
                        edgePoints = append(edgePoints, EdgePoint{
×
2726
                                FundingPkScript: pkScript,
×
2727
                                OutPoint:        *op,
×
2728
                        })
×
UNCOV
2729

×
2730
                        return nil
×
2731
                }
2732

2733
                queryFunc := func(ctx context.Context, lastID int64,
×
2734
                        limit int32) ([]sqlc.ListChannelsPaginatedRow, error) {
×
2735

×
UNCOV
2736
                        return db.ListChannelsPaginated(
×
UNCOV
2737
                                ctx, sqlc.ListChannelsPaginatedParams{
×
2738
                                        Version: int16(lnwire.GossipVersion1),
×
2739
                                        ID:      lastID,
×
2740
                                        Limit:   limit,
×
2741
                                },
×
2742
                        )
×
2743
                }
×
2744

2745
                extractCursor := func(row sqlc.ListChannelsPaginatedRow) int64 {
×
2746
                        return row.ID
×
2747
                }
×
2748

UNCOV
2749
                return sqldb.ExecutePaginatedQuery(
×
2750
                        ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
×
2751
                        extractCursor, handleChannel,
×
2752
                )
×
UNCOV
2753
        }, func() {
×
2754
                edgePoints = nil
×
2755
        })
×
2756
        if err != nil {
×
2757
                return nil, fmt.Errorf("unable to fetch channel view: %w", err)
×
2758
        }
×
2759

2760
        return edgePoints, nil
×
2761
}
2762

2763
// PruneTip returns the block height and hash of the latest block that has been
2764
// used to prune channels in the graph. Knowing the "prune tip" allows callers
2765
// to tell if the graph is currently in sync with the current best known UTXO
2766
// state.
2767
//
2768
// NOTE: part of the V1Store interface.
UNCOV
2769
func (s *SQLStore) PruneTip() (*chainhash.Hash, uint32, error) {
×
UNCOV
2770
        var (
×
UNCOV
2771
                ctx       = context.TODO()
×
UNCOV
2772
                tipHash   chainhash.Hash
×
UNCOV
2773
                tipHeight uint32
×
2774
        )
×
2775
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2776
                pruneTip, err := db.GetPruneTip(ctx)
×
2777
                if errors.Is(err, sql.ErrNoRows) {
×
2778
                        return ErrGraphNeverPruned
×
2779
                } else if err != nil {
×
2780
                        return fmt.Errorf("unable to fetch prune tip: %w", err)
×
2781
                }
×
2782

2783
                tipHash = chainhash.Hash(pruneTip.BlockHash)
×
2784
                tipHeight = uint32(pruneTip.BlockHeight)
×
2785

×
2786
                return nil
×
2787
        }, sqldb.NoOpReset)
2788
        if err != nil {
×
2789
                return nil, 0, err
×
2790
        }
×
2791

UNCOV
2792
        return &tipHash, tipHeight, nil
×
2793
}
2794

2795
// pruneGraphNodes deletes any node in the DB that doesn't have a channel.
2796
//
2797
// NOTE: this prunes nodes across protocol versions. It will never prune the
2798
// source nodes.
2799
func (s *SQLStore) pruneGraphNodes(ctx context.Context,
UNCOV
2800
        db SQLQueries) ([]route.Vertex, error) {
×
UNCOV
2801

×
UNCOV
2802
        nodeKeys, err := db.DeleteUnconnectedNodes(ctx)
×
UNCOV
2803
        if err != nil {
×
UNCOV
2804
                return nil, fmt.Errorf("unable to delete unconnected "+
×
2805
                        "nodes: %w", err)
×
2806
        }
×
2807

2808
        prunedNodes := make([]route.Vertex, len(nodeKeys))
×
2809
        for i, nodeKey := range nodeKeys {
×
2810
                pub, err := route.NewVertexFromBytes(nodeKey)
×
2811
                if err != nil {
×
UNCOV
2812
                        return nil, fmt.Errorf("unable to parse pubkey "+
×
2813
                                "from bytes: %w", err)
×
2814
                }
×
2815

2816
                prunedNodes[i] = pub
×
2817
        }
2818

2819
        return prunedNodes, nil
×
2820
}
2821

2822
// DisconnectBlockAtHeight is used to indicate that the block specified
2823
// by the passed height has been disconnected from the main chain. This
2824
// will "rewind" the graph back to the height below, deleting channels
2825
// that are no longer confirmed from the graph. The prune log will be
2826
// set to the last prune height valid for the remaining chain.
2827
// Channels that were removed from the graph resulting from the
2828
// disconnected block are returned.
2829
//
2830
// NOTE: part of the V1Store interface.
2831
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
UNCOV
2832
        []*models.ChannelEdgeInfo, error) {
×
UNCOV
2833

×
UNCOV
2834
        ctx := context.TODO()
×
UNCOV
2835

×
UNCOV
2836
        var (
×
2837
                // Every channel having a ShortChannelID starting at 'height'
×
2838
                // will no longer be confirmed.
×
2839
                startShortChanID = lnwire.ShortChannelID{
×
2840
                        BlockHeight: height,
×
2841
                }
×
2842

×
2843
                // Delete everything after this height from the db up until the
×
2844
                // SCID alias range.
×
2845
                endShortChanID = aliasmgr.StartingAlias
×
2846

×
2847
                removedChans []*models.ChannelEdgeInfo
×
2848

×
2849
                chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
×
2850
                chanIDEnd   = channelIDToBytes(endShortChanID.ToUint64())
×
2851
        )
×
2852

×
2853
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2854
                rows, err := db.GetChannelsBySCIDRange(
×
2855
                        ctx, sqlc.GetChannelsBySCIDRangeParams{
×
2856
                                StartScid: chanIDStart,
×
2857
                                EndScid:   chanIDEnd,
×
2858
                        },
×
2859
                )
×
2860
                if err != nil {
×
2861
                        return fmt.Errorf("unable to fetch channels: %w", err)
×
2862
                }
×
2863

2864
                if len(rows) == 0 {
×
2865
                        // No channels to disconnect, but still clean up prune
×
2866
                        // log.
×
2867
                        return db.DeletePruneLogEntriesInRange(
×
UNCOV
2868
                                ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2869
                                        StartHeight: int64(height),
×
2870
                                        EndHeight: int64(
×
2871
                                                endShortChanID.BlockHeight,
×
2872
                                        ),
×
2873
                                },
×
2874
                        )
×
2875
                }
×
2876

2877
                // Batch build all channel edges for disconnection.
2878
                channelEdges, chanIDsToDelete, err := batchBuildChannelInfo(
×
2879
                        ctx, s.cfg, db, rows,
×
2880
                )
×
UNCOV
2881
                if err != nil {
×
UNCOV
2882
                        return err
×
2883
                }
×
2884

2885
                removedChans = channelEdges
×
2886

×
2887
                err = s.deleteChannels(ctx, db, chanIDsToDelete)
×
2888
                if err != nil {
×
UNCOV
2889
                        return fmt.Errorf("unable to delete channels: %w", err)
×
2890
                }
×
2891

2892
                return db.DeletePruneLogEntriesInRange(
×
2893
                        ctx, sqlc.DeletePruneLogEntriesInRangeParams{
×
2894
                                StartHeight: int64(height),
×
2895
                                EndHeight:   int64(endShortChanID.BlockHeight),
×
UNCOV
2896
                        },
×
2897
                )
×
2898
        }, func() {
×
2899
                removedChans = nil
×
2900
        })
×
2901
        if err != nil {
×
2902
                return nil, fmt.Errorf("unable to disconnect block at "+
×
2903
                        "height: %w", err)
×
2904
        }
×
2905

2906
        for _, channel := range removedChans {
×
2907
                s.rejectCache.remove(channel.ChannelID)
×
2908
                s.chanCache.remove(channel.ChannelID)
×
2909
        }
×
2910

2911
        return removedChans, nil
×
2912
}
2913

2914
// AddEdgeProof sets the proof of an existing edge in the graph database.
2915
//
2916
// NOTE: part of the V1Store interface.
2917
func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
UNCOV
2918
        proof *models.ChannelAuthProof) error {
×
UNCOV
2919

×
UNCOV
2920
        var (
×
UNCOV
2921
                ctx       = context.TODO()
×
UNCOV
2922
                scidBytes = channelIDToBytes(scid.ToUint64())
×
2923
        )
×
2924

×
2925
        err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2926
                res, err := db.AddV1ChannelProof(
×
2927
                        ctx, sqlc.AddV1ChannelProofParams{
×
2928
                                Scid:              scidBytes,
×
2929
                                Node1Signature:    proof.NodeSig1Bytes,
×
2930
                                Node2Signature:    proof.NodeSig2Bytes,
×
2931
                                Bitcoin1Signature: proof.BitcoinSig1Bytes,
×
2932
                                Bitcoin2Signature: proof.BitcoinSig2Bytes,
×
2933
                        },
×
2934
                )
×
2935
                if err != nil {
×
2936
                        return fmt.Errorf("unable to add edge proof: %w", err)
×
2937
                }
×
2938

2939
                n, err := res.RowsAffected()
×
2940
                if err != nil {
×
2941
                        return err
×
2942
                }
×
2943

2944
                if n == 0 {
×
2945
                        return fmt.Errorf("no rows affected when adding edge "+
×
2946
                                "proof for SCID %v", scid)
×
2947
                } else if n > 1 {
×
UNCOV
2948
                        return fmt.Errorf("multiple rows affected when adding "+
×
2949
                                "edge proof for SCID %v: %d rows affected",
×
2950
                                scid, n)
×
2951
                }
×
2952

2953
                return nil
×
2954
        }, sqldb.NoOpReset)
2955
        if err != nil {
×
2956
                return fmt.Errorf("unable to add edge proof: %w", err)
×
UNCOV
2957
        }
×
2958

UNCOV
2959
        return nil
×
2960
}
2961

2962
// PutClosedScid stores a SCID for a closed channel in the database. This is so
2963
// that we can ignore channel announcements that we know to be closed without
2964
// having to validate them and fetch a block.
2965
//
2966
// NOTE: part of the V1Store interface.
UNCOV
2967
func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
×
UNCOV
2968
        var (
×
UNCOV
2969
                ctx     = context.TODO()
×
UNCOV
2970
                chanIDB = channelIDToBytes(scid.ToUint64())
×
UNCOV
2971
        )
×
2972

×
2973
        return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
×
2974
                return db.InsertClosedChannel(ctx, chanIDB)
×
2975
        }, sqldb.NoOpReset)
×
2976
}
2977

2978
// IsClosedScid checks whether a channel identified by the passed in scid is
2979
// closed. This helps avoid having to perform expensive validation checks.
2980
//
2981
// NOTE: part of the V1Store interface.
UNCOV
2982
func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
×
UNCOV
2983
        var (
×
UNCOV
2984
                ctx      = context.TODO()
×
UNCOV
2985
                isClosed bool
×
UNCOV
2986
                chanIDB  = channelIDToBytes(scid.ToUint64())
×
2987
        )
×
2988
        err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
2989
                var err error
×
2990
                isClosed, err = db.IsClosedChannel(ctx, chanIDB)
×
2991
                if err != nil {
×
2992
                        return fmt.Errorf("unable to fetch closed channel: %w",
×
2993
                                err)
×
2994
                }
×
2995

2996
                return nil
×
2997
        }, sqldb.NoOpReset)
2998
        if err != nil {
×
2999
                return false, fmt.Errorf("unable to fetch closed channel: %w",
×
UNCOV
3000
                        err)
×
3001
        }
×
3002

3003
        return isClosed, nil
×
3004
}
3005

3006
// GraphSession will provide the call-back with access to a NodeTraverser
3007
// instance which can be used to perform queries against the channel graph.
3008
//
3009
// NOTE: part of the V1Store interface.
3010
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error,
UNCOV
3011
        reset func()) error {
×
UNCOV
3012

×
UNCOV
3013
        var ctx = context.TODO()
×
UNCOV
3014

×
UNCOV
3015
        return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
×
3016
                return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
×
3017
        }, reset)
×
3018
}
3019

3020
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
3021
// read only transaction for a consistent view of the graph.
3022
type sqlNodeTraverser struct {
3023
        db    SQLQueries
3024
        chain chainhash.Hash
3025
}
3026

3027
// A compile-time assertion to ensure that sqlNodeTraverser implements the
3028
// NodeTraverser interface.
3029
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
3030

3031
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
3032
func newSQLNodeTraverser(db SQLQueries,
UNCOV
3033
        chain chainhash.Hash) *sqlNodeTraverser {
×
UNCOV
3034

×
UNCOV
3035
        return &sqlNodeTraverser{
×
UNCOV
3036
                db:    db,
×
UNCOV
3037
                chain: chain,
×
3038
        }
×
3039
}
×
3040

3041
// ForEachNodeDirectedChannel calls the callback for every channel of the given
3042
// node.
3043
//
3044
// NOTE: Part of the NodeTraverser interface.
3045
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
UNCOV
3046
        cb func(channel *DirectedChannel) error, _ func()) error {
×
UNCOV
3047

×
UNCOV
3048
        ctx := context.TODO()
×
UNCOV
3049

×
UNCOV
3050
        return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
×
3051
}
×
3052

3053
// FetchNodeFeatures returns the features of the given node. If the node is
3054
// unknown, assume no additional features are supported.
3055
//
3056
// NOTE: Part of the NodeTraverser interface.
3057
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
UNCOV
3058
        *lnwire.FeatureVector, error) {
×
UNCOV
3059

×
UNCOV
3060
        ctx := context.TODO()
×
UNCOV
3061

×
UNCOV
3062
        return fetchNodeFeatures(ctx, s.db, nodePub)
×
3063
}
×
3064

3065
// forEachNodeDirectedChannel iterates through all channels of a given
3066
// node, executing the passed callback on the directed edge representing the
3067
// channel and its incoming policy. If the node is not found, no error is
3068
// returned.
3069
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
UNCOV
3070
        nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
×
UNCOV
3071

×
UNCOV
3072
        toNodeCallback := func() route.Vertex {
×
UNCOV
3073
                return nodePub
×
UNCOV
3074
        }
×
3075

3076
        dbID, err := db.GetNodeIDByPubKey(
×
3077
                ctx, sqlc.GetNodeIDByPubKeyParams{
×
3078
                        Version: int16(lnwire.GossipVersion1),
×
3079
                        PubKey:  nodePub[:],
×
UNCOV
3080
                },
×
3081
        )
×
3082
        if errors.Is(err, sql.ErrNoRows) {
×
3083
                return nil
×
3084
        } else if err != nil {
×
3085
                return fmt.Errorf("unable to fetch node: %w", err)
×
3086
        }
×
3087

3088
        rows, err := db.ListChannelsByNodeID(
×
3089
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3090
                        Version: int16(lnwire.GossipVersion1),
×
3091
                        NodeID1: dbID,
×
UNCOV
3092
                },
×
3093
        )
×
3094
        if err != nil {
×
3095
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3096
        }
×
3097

3098
        // Exit early if there are no channels for this node so we don't
3099
        // do the unnecessary feature fetching.
3100
        if len(rows) == 0 {
×
3101
                return nil
×
UNCOV
3102
        }
×
3103

UNCOV
3104
        features, err := getNodeFeatures(ctx, db, dbID)
×
3105
        if err != nil {
×
3106
                return fmt.Errorf("unable to fetch node features: %w", err)
×
3107
        }
×
3108

3109
        for _, row := range rows {
×
3110
                node1, node2, err := buildNodeVertices(
×
3111
                        row.Node1Pubkey, row.Node2Pubkey,
×
3112
                )
×
UNCOV
3113
                if err != nil {
×
3114
                        return fmt.Errorf("unable to build node vertices: %w",
×
3115
                                err)
×
3116
                }
×
3117

3118
                edge := buildCacheableChannelInfo(
×
3119
                        row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
×
3120
                        node1, node2,
×
3121
                )
×
UNCOV
3122

×
3123
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3124
                if err != nil {
×
3125
                        return err
×
3126
                }
×
3127

3128
                p1, p2, err := buildCachedChanPolicies(
×
3129
                        dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
3130
                )
×
3131
                if err != nil {
×
UNCOV
3132
                        return err
×
3133
                }
×
3134

3135
                // Determine the outgoing and incoming policy for this
3136
                // channel and node combo.
3137
                outPolicy, inPolicy := p1, p2
×
3138
                if p1 != nil && node2 == nodePub {
×
UNCOV
3139
                        outPolicy, inPolicy = p2, p1
×
UNCOV
3140
                } else if p2 != nil && node1 != nodePub {
×
UNCOV
3141
                        outPolicy, inPolicy = p2, p1
×
3142
                }
×
3143

3144
                var cachedInPolicy *models.CachedEdgePolicy
×
3145
                if inPolicy != nil {
×
3146
                        cachedInPolicy = inPolicy
×
3147
                        cachedInPolicy.ToNodePubKey = toNodeCallback
×
UNCOV
3148
                        cachedInPolicy.ToNodeFeatures = features
×
3149
                }
×
3150

3151
                directedChannel := &DirectedChannel{
×
3152
                        ChannelID:    edge.ChannelID,
×
3153
                        IsNode1:      nodePub == edge.NodeKey1Bytes,
×
3154
                        OtherNode:    edge.NodeKey2Bytes,
×
UNCOV
3155
                        Capacity:     edge.Capacity,
×
3156
                        OutPolicySet: outPolicy != nil,
×
3157
                        InPolicy:     cachedInPolicy,
×
3158
                }
×
3159
                if outPolicy != nil {
×
3160
                        outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3161
                                directedChannel.InboundFee = fee
×
3162
                        })
×
3163
                }
3164

3165
                if nodePub == edge.NodeKey2Bytes {
×
3166
                        directedChannel.OtherNode = edge.NodeKey1Bytes
×
3167
                }
×
3168

UNCOV
3169
                if err := cb(directedChannel); err != nil {
×
3170
                        return err
×
3171
                }
×
3172
        }
3173

3174
        return nil
×
3175
}
3176

3177
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
3178
// and executes the provided callback for each node. It does so via pagination
3179
// along with batch loading of the node feature bits.
3180
func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
3181
        db SQLQueries, processNode func(nodeID int64, nodePub route.Vertex,
UNCOV
3182
                features *lnwire.FeatureVector) error) error {
×
UNCOV
3183

×
UNCOV
3184
        handleNode := func(_ context.Context,
×
UNCOV
3185
                dbNode sqlc.ListNodeIDsAndPubKeysRow,
×
UNCOV
3186
                featureBits map[int64][]int) error {
×
3187

×
3188
                fv := lnwire.EmptyFeatureVector()
×
3189
                if features, exists := featureBits[dbNode.ID]; exists {
×
3190
                        for _, bit := range features {
×
3191
                                fv.Set(lnwire.FeatureBit(bit))
×
3192
                        }
×
3193
                }
3194

3195
                var pub route.Vertex
×
3196
                copy(pub[:], dbNode.PubKey)
×
3197

×
UNCOV
3198
                return processNode(dbNode.ID, pub, fv)
×
3199
        }
3200

3201
        queryFunc := func(ctx context.Context, lastID int64,
×
3202
                limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
×
3203

×
UNCOV
3204
                return db.ListNodeIDsAndPubKeys(
×
UNCOV
3205
                        ctx, sqlc.ListNodeIDsAndPubKeysParams{
×
3206
                                Version: int16(lnwire.GossipVersion1),
×
3207
                                ID:      lastID,
×
3208
                                Limit:   limit,
×
3209
                        },
×
3210
                )
×
3211
        }
×
3212

3213
        extractCursor := func(row sqlc.ListNodeIDsAndPubKeysRow) int64 {
×
3214
                return row.ID
×
3215
        }
×
3216

UNCOV
3217
        collectFunc := func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, error) {
×
3218
                return node.ID, nil
×
3219
        }
×
3220

UNCOV
3221
        batchQueryFunc := func(ctx context.Context,
×
3222
                nodeIDs []int64) (map[int64][]int, error) {
×
3223

×
3224
                return batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
UNCOV
3225
        }
×
3226

3227
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
3228
                ctx, cfg, int64(-1), queryFunc, extractCursor, collectFunc,
×
3229
                batchQueryFunc, handleNode,
×
3230
        )
×
3231
}
3232

3233
// forEachNodeChannel iterates through all channels of a node, executing
3234
// the passed callback on each. The call-back is provided with the channel's
3235
// edge information, the outgoing policy and the incoming policy for the
3236
// channel and node combo.
3237
func forEachNodeChannel(ctx context.Context, db SQLQueries,
3238
        cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
3239
                *models.ChannelEdgePolicy,
UNCOV
3240
                *models.ChannelEdgePolicy) error) error {
×
UNCOV
3241

×
UNCOV
3242
        // Get all the V1 channels for this node.
×
UNCOV
3243
        rows, err := db.ListChannelsByNodeID(
×
UNCOV
3244
                ctx, sqlc.ListChannelsByNodeIDParams{
×
3245
                        Version: int16(lnwire.GossipVersion1),
×
3246
                        NodeID1: id,
×
3247
                },
×
3248
        )
×
3249
        if err != nil {
×
3250
                return fmt.Errorf("unable to fetch channels: %w", err)
×
3251
        }
×
3252

3253
        // Collect all the channel and policy IDs.
3254
        var (
×
3255
                chanIDs   = make([]int64, 0, len(rows))
×
3256
                policyIDs = make([]int64, 0, 2*len(rows))
×
UNCOV
3257
        )
×
UNCOV
3258
        for _, row := range rows {
×
3259
                chanIDs = append(chanIDs, row.GraphChannel.ID)
×
3260

×
3261
                if row.Policy1ID.Valid {
×
3262
                        policyIDs = append(policyIDs, row.Policy1ID.Int64)
×
3263
                }
×
3264
                if row.Policy2ID.Valid {
×
3265
                        policyIDs = append(policyIDs, row.Policy2ID.Int64)
×
3266
                }
×
3267
        }
3268

3269
        batchData, err := batchLoadChannelData(
×
3270
                ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
×
3271
        )
×
UNCOV
3272
        if err != nil {
×
UNCOV
3273
                return fmt.Errorf("unable to batch load channel data: %w", err)
×
3274
        }
×
3275

3276
        // Call the call-back for each channel and its known policies.
3277
        for _, row := range rows {
×
3278
                node1, node2, err := buildNodeVertices(
×
3279
                        row.Node1Pubkey, row.Node2Pubkey,
×
UNCOV
3280
                )
×
UNCOV
3281
                if err != nil {
×
3282
                        return fmt.Errorf("unable to build node vertices: %w",
×
3283
                                err)
×
3284
                }
×
3285

3286
                edge, err := buildEdgeInfoWithBatchData(
×
3287
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
3288
                        batchData,
×
3289
                )
×
UNCOV
3290
                if err != nil {
×
3291
                        return fmt.Errorf("unable to build channel info: %w",
×
3292
                                err)
×
3293
                }
×
3294

3295
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
3296
                if err != nil {
×
3297
                        return fmt.Errorf("unable to extract channel "+
×
3298
                                "policies: %w", err)
×
UNCOV
3299
                }
×
3300

3301
                p1, p2, err := buildChanPoliciesWithBatchData(
×
3302
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
3303
                )
×
3304
                if err != nil {
×
UNCOV
3305
                        return fmt.Errorf("unable to build channel "+
×
3306
                                "policies: %w", err)
×
3307
                }
×
3308

3309
                // Determine the outgoing and incoming policy for this
3310
                // channel and node combo.
3311
                p1ToNode := row.GraphChannel.NodeID2
×
3312
                p2ToNode := row.GraphChannel.NodeID1
×
UNCOV
3313
                outPolicy, inPolicy := p1, p2
×
UNCOV
3314
                if (p1 != nil && p1ToNode == id) ||
×
UNCOV
3315
                        (p2 != nil && p2ToNode != id) {
×
3316

×
3317
                        outPolicy, inPolicy = p2, p1
×
3318
                }
×
3319

3320
                if err := cb(edge, outPolicy, inPolicy); err != nil {
×
3321
                        return err
×
3322
                }
×
3323
        }
3324

3325
        return nil
×
3326
}
3327

3328
// updateChanEdgePolicy upserts the channel policy info we have stored for
3329
// a channel we already know of.
3330
func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
3331
        edge *models.ChannelEdgePolicy) (route.Vertex, route.Vertex, bool,
UNCOV
3332
        error) {
×
UNCOV
3333

×
UNCOV
3334
        var (
×
UNCOV
3335
                node1Pub, node2Pub route.Vertex
×
UNCOV
3336
                isNode1            bool
×
3337
                chanIDB            = channelIDToBytes(edge.ChannelID)
×
3338
        )
×
3339

×
3340
        // Check that this edge policy refers to a channel that we already
×
3341
        // know of. We do this explicitly so that we can return the appropriate
×
3342
        // ErrEdgeNotFound error if the channel doesn't exist, rather than
×
3343
        // abort the transaction which would abort the entire batch.
×
3344
        dbChan, err := tx.GetChannelAndNodesBySCID(
×
3345
                ctx, sqlc.GetChannelAndNodesBySCIDParams{
×
3346
                        Scid:    chanIDB,
×
3347
                        Version: int16(lnwire.GossipVersion1),
×
3348
                },
×
3349
        )
×
3350
        if errors.Is(err, sql.ErrNoRows) {
×
3351
                return node1Pub, node2Pub, false, ErrEdgeNotFound
×
3352
        } else if err != nil {
×
3353
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
3354
                        "fetch channel(%v): %w", edge.ChannelID, err)
×
3355
        }
×
3356

3357
        copy(node1Pub[:], dbChan.Node1PubKey)
×
3358
        copy(node2Pub[:], dbChan.Node2PubKey)
×
3359

×
3360
        // Figure out which node this edge is from.
×
UNCOV
3361
        isNode1 = edge.ChannelFlags&lnwire.ChanUpdateDirection == 0
×
3362
        nodeID := dbChan.NodeID1
×
3363
        if !isNode1 {
×
3364
                nodeID = dbChan.NodeID2
×
3365
        }
×
3366

3367
        var (
×
3368
                inboundBase sql.NullInt64
×
3369
                inboundRate sql.NullInt64
×
3370
        )
×
UNCOV
3371
        edge.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
3372
                inboundRate = sqldb.SQLInt64(fee.FeeRate)
×
3373
                inboundBase = sqldb.SQLInt64(fee.BaseFee)
×
3374
        })
×
3375

3376
        id, err := tx.UpsertEdgePolicy(ctx, sqlc.UpsertEdgePolicyParams{
×
3377
                Version:     int16(lnwire.GossipVersion1),
×
3378
                ChannelID:   dbChan.ID,
×
3379
                NodeID:      nodeID,
×
UNCOV
3380
                Timelock:    int32(edge.TimeLockDelta),
×
3381
                FeePpm:      int64(edge.FeeProportionalMillionths),
×
3382
                BaseFeeMsat: int64(edge.FeeBaseMSat),
×
3383
                MinHtlcMsat: int64(edge.MinHTLC),
×
3384
                LastUpdate:  sqldb.SQLInt64(edge.LastUpdate.Unix()),
×
3385
                Disabled: sql.NullBool{
×
3386
                        Valid: true,
×
3387
                        Bool:  edge.IsDisabled(),
×
3388
                },
×
3389
                MaxHtlcMsat: sql.NullInt64{
×
3390
                        Valid: edge.MessageFlags.HasMaxHtlc(),
×
3391
                        Int64: int64(edge.MaxHTLC),
×
3392
                },
×
3393
                MessageFlags:            sqldb.SQLInt16(edge.MessageFlags),
×
3394
                ChannelFlags:            sqldb.SQLInt16(edge.ChannelFlags),
×
3395
                InboundBaseFeeMsat:      inboundBase,
×
3396
                InboundFeeRateMilliMsat: inboundRate,
×
3397
                Signature:               edge.SigBytes,
×
3398
        })
×
3399
        if err != nil {
×
3400
                return node1Pub, node2Pub, isNode1,
×
3401
                        fmt.Errorf("unable to upsert edge policy: %w", err)
×
3402
        }
×
3403

3404
        // Convert the flat extra opaque data into a map of TLV types to
3405
        // values.
3406
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
3407
        if err != nil {
×
UNCOV
3408
                return node1Pub, node2Pub, false, fmt.Errorf("unable to "+
×
UNCOV
3409
                        "marshal extra opaque data: %w", err)
×
UNCOV
3410
        }
×
3411

3412
        // Update the channel policy's extra signed fields.
3413
        err = upsertChanPolicyExtraSignedFields(ctx, tx, id, extra)
×
3414
        if err != nil {
×
3415
                return node1Pub, node2Pub, false, fmt.Errorf("inserting chan "+
×
UNCOV
3416
                        "policy extra TLVs: %w", err)
×
UNCOV
3417
        }
×
3418

3419
        return node1Pub, node2Pub, isNode1, nil
×
3420
}
3421

3422
// getNodeByPubKey attempts to look up a target node by its public key.
3423
func getNodeByPubKey(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
3424
        pubKey route.Vertex) (int64, *models.Node, error) {
×
UNCOV
3425

×
UNCOV
3426
        dbNode, err := db.GetNodeByPubKey(
×
UNCOV
3427
                ctx, sqlc.GetNodeByPubKeyParams{
×
UNCOV
3428
                        Version: int16(lnwire.GossipVersion1),
×
3429
                        PubKey:  pubKey[:],
×
3430
                },
×
3431
        )
×
3432
        if errors.Is(err, sql.ErrNoRows) {
×
3433
                return 0, nil, ErrGraphNodeNotFound
×
3434
        } else if err != nil {
×
3435
                return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
×
3436
        }
×
3437

3438
        node, err := buildNode(ctx, cfg, db, dbNode)
×
3439
        if err != nil {
×
3440
                return 0, nil, fmt.Errorf("unable to build node: %w", err)
×
3441
        }
×
3442

3443
        return dbNode.ID, node, nil
×
3444
}
3445

3446
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3447
// provided parameters.
3448
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
UNCOV
3449
        node2Pub route.Vertex) *models.CachedEdgeInfo {
×
UNCOV
3450

×
UNCOV
3451
        return &models.CachedEdgeInfo{
×
UNCOV
3452
                ChannelID:     byteOrder.Uint64(scid),
×
UNCOV
3453
                NodeKey1Bytes: node1Pub,
×
3454
                NodeKey2Bytes: node2Pub,
×
3455
                Capacity:      btcutil.Amount(capacity),
×
3456
        }
×
3457
}
×
3458

3459
// buildNode constructs a Node instance from the given database node
3460
// record. The node's features, addresses and extra signed fields are also
3461
// fetched from the database and set on the node.
3462
func buildNode(ctx context.Context, cfg *sqldb.QueryConfig, db SQLQueries,
UNCOV
3463
        dbNode sqlc.GraphNode) (*models.Node, error) {
×
UNCOV
3464

×
UNCOV
3465
        data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
×
UNCOV
3466
        if err != nil {
×
UNCOV
3467
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
3468
                        err)
×
3469
        }
×
3470

3471
        return buildNodeWithBatchData(dbNode, data)
×
3472
}
3473

3474
// buildNodeWithBatchData builds a models.Node instance
3475
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
3476
// features/addresses/extra fields, then the corresponding fields are expected
3477
// to be present in the batchNodeData.
3478
func buildNodeWithBatchData(dbNode sqlc.GraphNode,
UNCOV
3479
        batchData *batchNodeData) (*models.Node, error) {
×
UNCOV
3480

×
UNCOV
3481
        if dbNode.Version != int16(lnwire.GossipVersion1) {
×
UNCOV
3482
                return nil, fmt.Errorf("unsupported node version: %d",
×
UNCOV
3483
                        dbNode.Version)
×
3484
        }
×
3485

3486
        var pub [33]byte
×
3487
        copy(pub[:], dbNode.PubKey)
×
3488

×
3489
        node := models.NewV1ShellNode(pub)
×
UNCOV
3490

×
3491
        if len(dbNode.Signature) == 0 {
×
3492
                return node, nil
×
3493
        }
×
3494

3495
        node.AuthSigBytes = dbNode.Signature
×
3496

×
3497
        if dbNode.Alias.Valid {
×
3498
                node.Alias = fn.Some(dbNode.Alias.String)
×
3499
        }
×
3500
        if dbNode.LastUpdate.Valid {
×
3501
                node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
×
3502
        }
×
3503

3504
        var err error
×
3505
        if dbNode.Color.Valid {
×
3506
                nodeColor, err := DecodeHexColor(dbNode.Color.String)
×
3507
                if err != nil {
×
3508
                        return nil, fmt.Errorf("unable to decode color: %w",
×
3509
                                err)
×
3510
                }
×
3511

3512
                node.Color = fn.Some(nodeColor)
×
3513
        }
3514

3515
        // Use preloaded features.
UNCOV
3516
        if features, exists := batchData.features[dbNode.ID]; exists {
×
UNCOV
3517
                fv := lnwire.EmptyFeatureVector()
×
UNCOV
3518
                for _, bit := range features {
×
3519
                        fv.Set(lnwire.FeatureBit(bit))
×
3520
                }
×
3521
                node.Features = fv
×
3522
        }
3523

3524
        // Use preloaded addresses.
UNCOV
3525
        addresses, exists := batchData.addresses[dbNode.ID]
×
UNCOV
3526
        if exists && len(addresses) > 0 {
×
UNCOV
3527
                node.Addresses, err = buildNodeAddresses(addresses)
×
3528
                if err != nil {
×
3529
                        return nil, fmt.Errorf("unable to build addresses "+
×
3530
                                "for node(%d): %w", dbNode.ID, err)
×
3531
                }
×
3532
        }
3533

3534
        // Use preloaded extra fields.
UNCOV
3535
        if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
×
UNCOV
3536
                recs, err := lnwire.CustomRecords(extraFields).Serialize()
×
UNCOV
3537
                if err != nil {
×
3538
                        return nil, fmt.Errorf("unable to serialize extra "+
×
3539
                                "signed fields: %w", err)
×
3540
                }
×
3541
                if len(recs) != 0 {
×
3542
                        node.ExtraOpaqueData = recs
×
3543
                }
×
3544
        }
3545

3546
        return node, nil
×
3547
}
3548

3549
// forEachNodeInBatch fetches all nodes in the provided batch, builds them
3550
// with the preloaded data, and executes the provided callback for each node.
3551
func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig,
3552
        db SQLQueries, nodes []sqlc.GraphNode,
UNCOV
3553
        cb func(dbID int64, node *models.Node) error) error {
×
UNCOV
3554

×
UNCOV
3555
        // Extract node IDs for batch loading.
×
3556
        nodeIDs := make([]int64, len(nodes))
×
3557
        for i, node := range nodes {
×
3558
                nodeIDs[i] = node.ID
×
3559
        }
×
3560

3561
        // Batch load all related data for this page.
3562
        batchData, err := batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
UNCOV
3563
        if err != nil {
×
UNCOV
3564
                return fmt.Errorf("unable to batch load node data: %w", err)
×
3565
        }
×
3566

3567
        for _, dbNode := range nodes {
×
3568
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
UNCOV
3569
                if err != nil {
×
3570
                        return fmt.Errorf("unable to build node(id=%d): %w",
×
3571
                                dbNode.ID, err)
×
3572
                }
×
3573

3574
                if err := cb(dbNode.ID, node); err != nil {
×
3575
                        return fmt.Errorf("callback failed for node(id=%d): %w",
×
UNCOV
3576
                                dbNode.ID, err)
×
3577
                }
×
3578
        }
3579

3580
        return nil
×
3581
}
3582

3583
// getNodeFeatures fetches the feature bits and constructs the feature vector
3584
// for a node with the given DB ID.
3585
func getNodeFeatures(ctx context.Context, db SQLQueries,
UNCOV
3586
        nodeID int64) (*lnwire.FeatureVector, error) {
×
UNCOV
3587

×
UNCOV
3588
        rows, err := db.GetNodeFeatures(ctx, nodeID)
×
3589
        if err != nil {
×
3590
                return nil, fmt.Errorf("unable to get node(%d) features: %w",
×
3591
                        nodeID, err)
×
3592
        }
×
3593

3594
        features := lnwire.EmptyFeatureVector()
×
3595
        for _, feature := range rows {
×
UNCOV
3596
                features.Set(lnwire.FeatureBit(feature.FeatureBit))
×
3597
        }
×
3598

3599
        return features, nil
×
3600
}
3601

3602
// upsertNode upserts the node record into the database. If the node already
3603
// exists, then the node's information is updated. If the node doesn't exist,
3604
// then a new node is created. The node's features, addresses and extra TLV
3605
// types are also updated. The node's DB ID is returned.
3606
func upsertNode(ctx context.Context, db SQLQueries,
UNCOV
3607
        node *models.Node) (int64, error) {
×
UNCOV
3608

×
UNCOV
3609
        params := sqlc.UpsertNodeParams{
×
3610
                Version: int16(lnwire.GossipVersion1),
×
3611
                PubKey:  node.PubKeyBytes[:],
×
3612
        }
×
3613

×
3614
        if node.HaveAnnouncement() {
×
3615
                switch node.Version {
×
3616
                case lnwire.GossipVersion1:
×
3617
                        params.LastUpdate = sqldb.SQLInt64(
×
3618
                                node.LastUpdate.Unix(),
×
3619
                        )
×
3620

3621
                case lnwire.GossipVersion2:
×
3622

UNCOV
3623
                default:
×
3624
                        return 0, fmt.Errorf("unknown gossip version: %d",
×
3625
                                node.Version)
×
3626
                }
3627

3628
                node.Color.WhenSome(func(rgba color.RGBA) {
×
UNCOV
3629
                        params.Color = sqldb.SQLStrValid(EncodeHexColor(rgba))
×
UNCOV
3630
                })
×
3631
                node.Alias.WhenSome(func(s string) {
×
3632
                        params.Alias = sqldb.SQLStrValid(s)
×
3633
                })
×
3634

UNCOV
3635
                params.Signature = node.AuthSigBytes
×
3636
        }
3637

3638
        nodeID, err := db.UpsertNode(ctx, params)
×
3639
        if err != nil {
×
UNCOV
3640
                return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
×
UNCOV
3641
                        err)
×
3642
        }
×
3643

3644
        // We can exit here if we don't have the announcement yet.
3645
        if !node.HaveAnnouncement() {
×
UNCOV
3646
                return nodeID, nil
×
UNCOV
3647
        }
×
3648

3649
        // Update the node's features.
3650
        err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
×
3651
        if err != nil {
×
3652
                return 0, fmt.Errorf("inserting node features: %w", err)
×
3653
        }
×
3654

3655
        // Update the node's addresses.
3656
        err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
×
3657
        if err != nil {
×
3658
                return 0, fmt.Errorf("inserting node addresses: %w", err)
×
3659
        }
×
3660

3661
        // Convert the flat extra opaque data into a map of TLV types to
3662
        // values.
UNCOV
3663
        extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
×
UNCOV
3664
        if err != nil {
×
UNCOV
3665
                return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
×
UNCOV
3666
                        err)
×
UNCOV
3667
        }
×
3668

3669
        // Update the node's extra signed fields.
3670
        err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
×
3671
        if err != nil {
×
3672
                return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
×
3673
        }
×
3674

3675
        return nodeID, nil
×
3676
}
3677

3678
// upsertNodeFeatures updates the node's features node_features table. This
3679
// includes deleting any feature bits no longer present and inserting any new
3680
// feature bits. If the feature bit does not yet exist in the features table,
3681
// then an entry is created in that table first.
3682
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
UNCOV
3683
        features *lnwire.FeatureVector) error {
×
UNCOV
3684

×
UNCOV
3685
        // Get any existing features for the node.
×
UNCOV
3686
        existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
×
UNCOV
3687
        if err != nil && !errors.Is(err, sql.ErrNoRows) {
×
3688
                return err
×
3689
        }
×
3690

3691
        // Copy the nodes latest set of feature bits.
3692
        newFeatures := make(map[int32]struct{})
×
3693
        if features != nil {
×
UNCOV
3694
                for feature := range features.Features() {
×
UNCOV
3695
                        newFeatures[int32(feature)] = struct{}{}
×
UNCOV
3696
                }
×
3697
        }
3698

3699
        // For any current feature that already exists in the DB, remove it from
3700
        // the in-memory map. For any existing feature that does not exist in
3701
        // the in-memory map, delete it from the database.
3702
        for _, feature := range existingFeatures {
×
3703
                // The feature is still present, so there are no updates to be
×
3704
                // made.
×
3705
                if _, ok := newFeatures[feature.FeatureBit]; ok {
×
3706
                        delete(newFeatures, feature.FeatureBit)
×
UNCOV
3707
                        continue
×
3708
                }
3709

3710
                // The feature is no longer present, so we remove it from the
3711
                // database.
3712
                err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
×
3713
                        NodeID:     nodeID,
×
3714
                        FeatureBit: feature.FeatureBit,
×
3715
                })
×
3716
                if err != nil {
×
3717
                        return fmt.Errorf("unable to delete node(%d) "+
×
3718
                                "feature(%v): %w", nodeID, feature.FeatureBit,
×
3719
                                err)
×
UNCOV
3720
                }
×
3721
        }
3722

3723
        // Any remaining entries in newFeatures are new features that need to be
3724
        // added to the database for the first time.
UNCOV
3725
        for feature := range newFeatures {
×
UNCOV
3726
                err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
×
3727
                        NodeID:     nodeID,
×
3728
                        FeatureBit: feature,
×
3729
                })
×
3730
                if err != nil {
×
3731
                        return fmt.Errorf("unable to insert node(%d) "+
×
3732
                                "feature(%v): %w", nodeID, feature, err)
×
3733
                }
×
3734
        }
3735

3736
        return nil
×
3737
}
3738

3739
// fetchNodeFeatures fetches the features for a node with the given public key.
3740
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
3741
        nodePub route.Vertex) (*lnwire.FeatureVector, error) {
×
3742

×
3743
        rows, err := queries.GetNodeFeaturesByPubKey(
×
UNCOV
3744
                ctx, sqlc.GetNodeFeaturesByPubKeyParams{
×
3745
                        PubKey:  nodePub[:],
×
UNCOV
3746
                        Version: int16(lnwire.GossipVersion1),
×
UNCOV
3747
                },
×
UNCOV
3748
        )
×
UNCOV
3749
        if err != nil {
×
UNCOV
3750
                return nil, fmt.Errorf("unable to get node(%s) features: %w",
×
UNCOV
3751
                        nodePub, err)
×
UNCOV
3752
        }
×
3753

UNCOV
3754
        features := lnwire.EmptyFeatureVector()
×
UNCOV
3755
        for _, bit := range rows {
×
UNCOV
3756
                features.Set(lnwire.FeatureBit(bit))
×
UNCOV
3757
        }
×
3758

UNCOV
3759
        return features, nil
×
3760
}
3761

3762
// dbAddressType enumerates the kinds of node address that are stored in the
// node_addresses table. The kind of an address determines how the address is
// serialised and deserialised.
type dbAddressType uint8

const (
	// addressTypeIPv4 is used for TCP addresses whose IP is an IPv4
	// address.
	addressTypeIPv4 dbAddressType = 1

	// addressTypeIPv6 is used for TCP addresses whose IP is an IPv6
	// address.
	addressTypeIPv6 dbAddressType = 2

	// addressTypeTorV2 is used for v2 onion addresses.
	addressTypeTorV2 dbAddressType = 3

	// addressTypeTorV3 is used for v3 onion addresses.
	addressTypeTorV3 dbAddressType = 4

	// addressTypeDNS is used for DNS addresses.
	addressTypeDNS dbAddressType = 5

	// addressTypeOpaque is used for addresses that are only available as
	// opaque address data.
	addressTypeOpaque dbAddressType = math.MaxInt8
)
3775

3776
// collectAddressRecords collects the addresses from the provided
3777
// net.Addr slice and returns a map of dbAddressType to a slice of address
3778
// strings.
3779
func collectAddressRecords(addresses []net.Addr) (map[dbAddressType][]string,
UNCOV
3780
        error) {
×
3781

×
3782
        // Copy the nodes latest set of addresses.
×
3783
        newAddresses := map[dbAddressType][]string{
×
3784
                addressTypeIPv4:   {},
×
3785
                addressTypeIPv6:   {},
×
3786
                addressTypeTorV2:  {},
×
3787
                addressTypeTorV3:  {},
×
3788
                addressTypeDNS:    {},
×
3789
                addressTypeOpaque: {},
×
3790
        }
×
3791
        addAddr := func(t dbAddressType, addr net.Addr) {
×
UNCOV
3792
                newAddresses[t] = append(newAddresses[t], addr.String())
×
3793
        }
×
3794

3795
        for _, address := range addresses {
×
3796
                switch addr := address.(type) {
×
3797
                case *net.TCPAddr:
×
3798
                        if ip4 := addr.IP.To4(); ip4 != nil {
×
3799
                                addAddr(addressTypeIPv4, addr)
×
3800
                        } else if ip6 := addr.IP.To16(); ip6 != nil {
×
3801
                                addAddr(addressTypeIPv6, addr)
×
UNCOV
3802
                        } else {
×
UNCOV
3803
                                return nil, fmt.Errorf("unhandled IP "+
×
3804
                                        "address: %v", addr)
×
3805
                        }
×
3806

3807
                case *tor.OnionAddr:
×
3808
                        switch len(addr.OnionService) {
×
UNCOV
3809
                        case tor.V2Len:
×
3810
                                addAddr(addressTypeTorV2, addr)
×
3811
                        case tor.V3Len:
×
3812
                                addAddr(addressTypeTorV3, addr)
×
UNCOV
3813
                        default:
×
UNCOV
3814
                                return nil, fmt.Errorf("invalid length for " +
×
UNCOV
3815
                                        "a tor address")
×
3816
                        }
3817

UNCOV
3818
                case *lnwire.DNSAddress:
×
UNCOV
3819
                        addAddr(addressTypeDNS, addr)
×
3820

UNCOV
3821
                case *lnwire.OpaqueAddrs:
×
UNCOV
3822
                        addAddr(addressTypeOpaque, addr)
×
3823

UNCOV
3824
                default:
×
UNCOV
3825
                        return nil, fmt.Errorf("unhandled address type: %T",
×
3826
                                addr)
×
3827
                }
3828
        }
3829

3830
        return newAddresses, nil
×
3831
}
3832

3833
// upsertNodeAddresses updates the node's addresses in the database. This
3834
// includes deleting any existing addresses and inserting the new set of
3835
// addresses. The deletion is necessary since the ordering of the addresses may
3836
// change, and we need to ensure that the database reflects the latest set of
3837
// addresses so that at the time of reconstructing the node announcement, the
3838
// order is preserved and the signature over the message remains valid.
3839
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3840
        addresses []net.Addr) error {
×
UNCOV
3841

×
UNCOV
3842
        // Delete any existing addresses for the node. This is required since
×
UNCOV
3843
        // even if the new set of addresses is the same, the ordering may have
×
3844
        // changed for a given address type.
×
3845
        err := db.DeleteNodeAddresses(ctx, nodeID)
×
3846
        if err != nil {
×
3847
                return fmt.Errorf("unable to delete node(%d) addresses: %w",
×
3848
                        nodeID, err)
×
3849
        }
×
3850

3851
        newAddresses, err := collectAddressRecords(addresses)
×
3852
        if err != nil {
×
3853
                return err
×
3854
        }
×
3855

3856
        // Any remaining entries in newAddresses are new addresses that need to
3857
        // be added to the database for the first time.
3858
        for addrType, addrList := range newAddresses {
×
UNCOV
3859
                for position, addr := range addrList {
×
UNCOV
3860
                        err := db.UpsertNodeAddress(
×
UNCOV
3861
                                ctx, sqlc.UpsertNodeAddressParams{
×
3862
                                        NodeID:   nodeID,
×
UNCOV
3863
                                        Type:     int16(addrType),
×
UNCOV
3864
                                        Address:  addr,
×
UNCOV
3865
                                        Position: int32(position),
×
UNCOV
3866
                                },
×
3867
                        )
×
3868
                        if err != nil {
×
3869
                                return fmt.Errorf("unable to insert "+
×
3870
                                        "node(%d) address(%v): %w", nodeID,
×
3871
                                        addr, err)
×
3872
                        }
×
3873
                }
3874
        }
3875

3876
        return nil
×
3877
}
3878

3879
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3880
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3881
        error) {
×
3882

×
3883
        // GetNodeAddresses ensures that the addresses for a given type are
×
3884
        // returned in the same order as they were inserted.
×
UNCOV
3885
        rows, err := db.GetNodeAddresses(ctx, id)
×
3886
        if err != nil {
×
UNCOV
3887
                return nil, err
×
UNCOV
3888
        }
×
3889

UNCOV
3890
        addresses := make([]net.Addr, 0, len(rows))
×
3891
        for _, row := range rows {
×
3892
                address := row.Address
×
3893

×
UNCOV
3894
                addr, err := parseAddress(dbAddressType(row.Type), address)
×
3895
                if err != nil {
×
UNCOV
3896
                        return nil, fmt.Errorf("unable to parse address "+
×
UNCOV
3897
                                "for node(%d): %v: %w", id, address, err)
×
UNCOV
3898
                }
×
3899

UNCOV
3900
                addresses = append(addresses, addr)
×
3901
        }
3902

3903
        // If we have no addresses, then we'll return nil instead of an
3904
        // empty slice.
3905
        if len(addresses) == 0 {
×
3906
                addresses = nil
×
3907
        }
×
3908

UNCOV
3909
        return addresses, nil
×
3910
}
3911

3912
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
3913
// database. This includes updating any existing types, inserting any new types,
3914
// and deleting any types that are no longer present.
3915
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
UNCOV
3916
        nodeID int64, extraFields map[uint64][]byte) error {
×
UNCOV
3917

×
UNCOV
3918
        // Get any existing extra signed fields for the node.
×
3919
        existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
×
3920
        if err != nil {
×
3921
                return err
×
3922
        }
×
3923

3924
        // Make a lookup map of the existing field types so that we can use it
3925
        // to keep track of any fields we should delete.
3926
        m := make(map[uint64]bool)
×
3927
        for _, field := range existingFields {
×
3928
                m[uint64(field.Type)] = true
×
3929
        }
×
3930

3931
        // For all the new fields, we'll upsert them and remove them from the
3932
        // map of existing fields.
UNCOV
3933
        for tlvType, value := range extraFields {
×
3934
                err = db.UpsertNodeExtraType(
×
UNCOV
3935
                        ctx, sqlc.UpsertNodeExtraTypeParams{
×
UNCOV
3936
                                NodeID: nodeID,
×
UNCOV
3937
                                Type:   int64(tlvType),
×
UNCOV
3938
                                Value:  value,
×
3939
                        },
×
3940
                )
×
3941
                if err != nil {
×
3942
                        return fmt.Errorf("unable to upsert node(%d) extra "+
×
3943
                                "signed field(%v): %w", nodeID, tlvType, err)
×
3944
                }
×
3945

3946
                // Remove the field from the map of existing fields if it was
3947
                // present.
3948
                delete(m, tlvType)
×
3949
        }
3950

3951
        // For all the fields that are left in the map of existing fields, we'll
3952
        // delete them as they are no longer present in the new set of fields.
UNCOV
3953
        for tlvType := range m {
×
UNCOV
3954
                err = db.DeleteExtraNodeType(
×
UNCOV
3955
                        ctx, sqlc.DeleteExtraNodeTypeParams{
×
UNCOV
3956
                                NodeID: nodeID,
×
UNCOV
3957
                                Type:   int64(tlvType),
×
UNCOV
3958
                        },
×
UNCOV
3959
                )
×
UNCOV
3960
                if err != nil {
×
UNCOV
3961
                        return fmt.Errorf("unable to delete node(%d) extra "+
×
UNCOV
3962
                                "signed field(%v): %w", nodeID, tlvType, err)
×
UNCOV
3963
                }
×
3964
        }
3965

UNCOV
3966
        return nil
×
3967
}
3968

3969
// srcNodeInfo holds the information about the source node of the graph.
3970
type srcNodeInfo struct {
3971
        // id is the DB level ID of the source node entry in the "nodes" table.
3972
        id int64
3973

3974
        // pub is the public key of the source node.
3975
        pub route.Vertex
3976
}
3977

3978
// sourceNode returns the DB node ID and pub key of the source node for the
3979
// specified protocol version.
3980
func (s *SQLStore) getSourceNode(ctx context.Context, db SQLQueries,
3981
        version lnwire.GossipVersion) (int64, route.Vertex, error) {
×
3982

×
3983
        s.srcNodeMu.Lock()
×
3984
        defer s.srcNodeMu.Unlock()
×
UNCOV
3985

×
3986
        // If we already have the source node ID and pub key cached, then
×
3987
        // return them.
×
3988
        if info, ok := s.srcNodes[version]; ok {
×
3989
                return info.id, info.pub, nil
×
3990
        }
×
3991

UNCOV
3992
        var pubKey route.Vertex
×
3993

×
3994
        nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
×
3995
        if err != nil {
×
3996
                return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
×
3997
                        err)
×
3998
        }
×
3999

4000
        if len(nodes) == 0 {
×
UNCOV
4001
                return 0, pubKey, ErrSourceNodeNotSet
×
UNCOV
4002
        } else if len(nodes) > 1 {
×
UNCOV
4003
                return 0, pubKey, fmt.Errorf("multiple source nodes for "+
×
UNCOV
4004
                        "protocol %s found", version)
×
UNCOV
4005
        }
×
4006

4007
        copy(pubKey[:], nodes[0].PubKey)
×
4008

×
4009
        s.srcNodes[version] = &srcNodeInfo{
×
4010
                id:  nodes[0].NodeID,
×
4011
                pub: pubKey,
×
4012
        }
×
UNCOV
4013

×
UNCOV
4014
        return nodes[0].NodeID, pubKey, nil
×
4015
}
4016

4017
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
4018
// This then produces a map from TLV type to value. If the input is not a
4019
// valid TLV stream, then an error is returned.
4020
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
×
4021
        r := bytes.NewReader(data)
×
4022

×
UNCOV
4023
        tlvStream, err := tlv.NewStream()
×
4024
        if err != nil {
×
4025
                return nil, err
×
4026
        }
×
4027

4028
        // Since ExtraOpaqueData is provided by a potentially malicious peer,
4029
        // pass it into the P2P decoding variant.
UNCOV
4030
        parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
×
UNCOV
4031
        if err != nil {
×
UNCOV
4032
                return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
×
UNCOV
4033
        }
×
4034
        if len(parsedTypes) == 0 {
×
4035
                return nil, nil
×
4036
        }
×
4037

4038
        records := make(map[uint64][]byte)
×
4039
        for k, v := range parsedTypes {
×
4040
                records[uint64(k)] = v
×
4041
        }
×
4042

4043
        return records, nil
×
4044
}
4045

4046
// insertChannel inserts a new channel record into the database.
4047
func insertChannel(ctx context.Context, db SQLQueries,
4048
        edge *models.ChannelEdgeInfo) error {
×
4049

×
4050
        // Make sure that at least a "shell" entry for each node is present in
×
4051
        // the nodes table.
×
UNCOV
4052
        node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
×
4053
        if err != nil {
×
4054
                return fmt.Errorf("unable to create shell node: %w", err)
×
4055
        }
×
4056

4057
        node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
×
4058
        if err != nil {
×
4059
                return fmt.Errorf("unable to create shell node: %w", err)
×
4060
        }
×
4061

4062
        var capacity sql.NullInt64
×
4063
        if edge.Capacity != 0 {
×
4064
                capacity = sqldb.SQLInt64(int64(edge.Capacity))
×
4065
        }
×
4066

4067
        createParams := sqlc.CreateChannelParams{
×
4068
                Version:     int16(lnwire.GossipVersion1),
×
4069
                Scid:        channelIDToBytes(edge.ChannelID),
×
4070
                NodeID1:     node1DBID,
×
4071
                NodeID2:     node2DBID,
×
UNCOV
4072
                Outpoint:    edge.ChannelPoint.String(),
×
UNCOV
4073
                Capacity:    capacity,
×
4074
                BitcoinKey1: edge.BitcoinKey1Bytes[:],
×
4075
                BitcoinKey2: edge.BitcoinKey2Bytes[:],
×
4076
        }
×
4077

×
UNCOV
4078
        if edge.AuthProof != nil {
×
UNCOV
4079
                proof := edge.AuthProof
×
4080

×
4081
                createParams.Node1Signature = proof.NodeSig1Bytes
×
4082
                createParams.Node2Signature = proof.NodeSig2Bytes
×
4083
                createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
×
4084
                createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
×
4085
        }
×
4086

4087
        // Insert the new channel record.
4088
        dbChanID, err := db.CreateChannel(ctx, createParams)
×
4089
        if err != nil {
×
4090
                return err
×
UNCOV
4091
        }
×
4092

4093
        // Insert any channel features.
4094
        for feature := range edge.Features.Features() {
×
4095
                err = db.InsertChannelFeature(
×
4096
                        ctx, sqlc.InsertChannelFeatureParams{
×
4097
                                ChannelID:  dbChanID,
×
4098
                                FeatureBit: int32(feature),
×
UNCOV
4099
                        },
×
4100
                )
×
4101
                if err != nil {
×
4102
                        return fmt.Errorf("unable to insert channel(%d) "+
×
4103
                                "feature(%v): %w", dbChanID, feature, err)
×
4104
                }
×
4105
        }
4106

4107
        // Finally, insert any extra TLV fields in the channel announcement.
4108
        extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
×
4109
        if err != nil {
×
4110
                return fmt.Errorf("unable to marshal extra opaque data: %w",
×
4111
                        err)
×
4112
        }
×
4113

UNCOV
4114
        for tlvType, value := range extra {
×
4115
                err := db.UpsertChannelExtraType(
×
UNCOV
4116
                        ctx, sqlc.UpsertChannelExtraTypeParams{
×
UNCOV
4117
                                ChannelID: dbChanID,
×
UNCOV
4118
                                Type:      int64(tlvType),
×
UNCOV
4119
                                Value:     value,
×
UNCOV
4120
                        },
×
UNCOV
4121
                )
×
UNCOV
4122
                if err != nil {
×
4123
                        return fmt.Errorf("unable to upsert channel(%d) "+
×
4124
                                "extra signed field(%v): %w", edge.ChannelID,
×
4125
                                tlvType, err)
×
4126
                }
×
4127
        }
4128

4129
        return nil
×
4130
}
4131

4132
// maybeCreateShellNode checks if a shell node entry exists for the
4133
// given public key. If it does not exist, then a new shell node entry is
4134
// created. The ID of the node is returned. A shell node only has a protocol
4135
// version and public key persisted.
4136
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
UNCOV
4137
        pubKey route.Vertex) (int64, error) {
×
UNCOV
4138

×
UNCOV
4139
        dbNode, err := db.GetNodeByPubKey(
×
4140
                ctx, sqlc.GetNodeByPubKeyParams{
×
4141
                        PubKey:  pubKey[:],
×
4142
                        Version: int16(lnwire.GossipVersion1),
×
4143
                },
×
4144
        )
×
4145
        // The node exists. Return the ID.
×
4146
        if err == nil {
×
UNCOV
4147
                return dbNode.ID, nil
×
4148
        } else if !errors.Is(err, sql.ErrNoRows) {
×
UNCOV
4149
                return 0, err
×
UNCOV
4150
        }
×
4151

4152
        // Otherwise, the node does not exist, so we create a shell entry for
4153
        // it.
UNCOV
4154
        id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
×
4155
                Version: int16(lnwire.GossipVersion1),
×
4156
                PubKey:  pubKey[:],
×
4157
        })
×
4158
        if err != nil {
×
4159
                return 0, fmt.Errorf("unable to create shell node: %w", err)
×
4160
        }
×
4161

4162
        return id, nil
×
4163
}
4164

4165
// upsertChanPolicyExtraSignedFields updates the policy's extra signed fields in
4166
// the database. This includes deleting any existing types and then inserting
4167
// the new types.
4168
func upsertChanPolicyExtraSignedFields(ctx context.Context, db SQLQueries,
4169
        chanPolicyID int64, extraFields map[uint64][]byte) error {
×
4170

×
4171
        // Delete all existing extra signed fields for the channel policy.
×
4172
        err := db.DeleteChannelPolicyExtraTypes(ctx, chanPolicyID)
×
4173
        if err != nil {
×
4174
                return fmt.Errorf("unable to delete "+
×
4175
                        "existing policy extra signed fields for policy %d: %w",
×
4176
                        chanPolicyID, err)
×
4177
        }
×
4178

4179
        // Insert all new extra signed fields for the channel policy.
UNCOV
4180
        for tlvType, value := range extraFields {
×
4181
                err = db.UpsertChanPolicyExtraType(
×
UNCOV
4182
                        ctx, sqlc.UpsertChanPolicyExtraTypeParams{
×
UNCOV
4183
                                ChannelPolicyID: chanPolicyID,
×
UNCOV
4184
                                Type:            int64(tlvType),
×
UNCOV
4185
                                Value:           value,
×
UNCOV
4186
                        },
×
UNCOV
4187
                )
×
UNCOV
4188
                if err != nil {
×
4189
                        return fmt.Errorf("unable to insert "+
×
4190
                                "channel_policy(%d) extra signed field(%v): %w",
×
4191
                                chanPolicyID, tlvType, err)
×
4192
                }
×
4193
        }
4194

4195
        return nil
×
4196
}
4197

4198
// getAndBuildEdgeInfo builds a models.ChannelEdgeInfo instance from the
4199
// provided dbChanRow and also fetches any other required information
4200
// to construct the edge info.
4201
func getAndBuildEdgeInfo(ctx context.Context, cfg *SQLStoreConfig,
4202
        db SQLQueries, dbChan sqlc.GraphChannel, node1,
UNCOV
4203
        node2 route.Vertex) (*models.ChannelEdgeInfo, error) {
×
UNCOV
4204

×
UNCOV
4205
        data, err := batchLoadChannelData(
×
UNCOV
4206
                ctx, cfg.QueryCfg, db, []int64{dbChan.ID}, nil,
×
4207
        )
×
4208
        if err != nil {
×
4209
                return nil, fmt.Errorf("unable to batch load channel data: %w",
×
4210
                        err)
×
4211
        }
×
4212

UNCOV
4213
        return buildEdgeInfoWithBatchData(
×
UNCOV
4214
                cfg.ChainHash, dbChan, node1, node2, data,
×
4215
        )
×
4216
}
4217

4218
// buildEdgeInfoWithBatchData builds edge info using pre-loaded batch data.
4219
func buildEdgeInfoWithBatchData(chain chainhash.Hash,
4220
        dbChan sqlc.GraphChannel, node1, node2 route.Vertex,
UNCOV
4221
        batchData *batchChannelData) (*models.ChannelEdgeInfo, error) {
×
4222

×
4223
        if dbChan.Version != int16(lnwire.GossipVersion1) {
×
4224
                return nil, fmt.Errorf("unsupported channel version: %d",
×
4225
                        dbChan.Version)
×
4226
        }
×
4227

4228
        // Use pre-loaded features and extras types.
UNCOV
4229
        fv := lnwire.EmptyFeatureVector()
×
4230
        if features, exists := batchData.chanfeatures[dbChan.ID]; exists {
×
4231
                for _, bit := range features {
×
4232
                        fv.Set(lnwire.FeatureBit(bit))
×
4233
                }
×
4234
        }
4235

4236
        var extras map[uint64][]byte
×
4237
        channelExtras, exists := batchData.chanExtraTypes[dbChan.ID]
×
4238
        if exists {
×
4239
                extras = channelExtras
×
4240
        } else {
×
4241
                extras = make(map[uint64][]byte)
×
4242
        }
×
4243

4244
        op, err := wire.NewOutPointFromString(dbChan.Outpoint)
×
4245
        if err != nil {
×
4246
                return nil, err
×
4247
        }
×
4248

4249
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
4250
        if err != nil {
×
4251
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4252
                        "fields: %w", err)
×
4253
        }
×
4254
        if recs == nil {
×
4255
                recs = make([]byte, 0)
×
4256
        }
×
4257

4258
        var btcKey1, btcKey2 route.Vertex
×
4259
        copy(btcKey1[:], dbChan.BitcoinKey1)
×
4260
        copy(btcKey2[:], dbChan.BitcoinKey2)
×
4261

×
4262
        channel := &models.ChannelEdgeInfo{
×
4263
                ChainHash:        chain,
×
4264
                ChannelID:        byteOrder.Uint64(dbChan.Scid),
×
4265
                NodeKey1Bytes:    node1,
×
4266
                NodeKey2Bytes:    node2,
×
4267
                BitcoinKey1Bytes: btcKey1,
×
4268
                BitcoinKey2Bytes: btcKey2,
×
4269
                ChannelPoint:     *op,
×
4270
                Capacity:         btcutil.Amount(dbChan.Capacity.Int64),
×
4271
                Features:         fv,
×
UNCOV
4272
                ExtraOpaqueData:  recs,
×
4273
        }
×
UNCOV
4274

×
UNCOV
4275
        // We always set all the signatures at the same time, so we can
×
UNCOV
4276
        // safely check if one signature is present to determine if we have the
×
UNCOV
4277
        // rest of the signatures for the auth proof.
×
UNCOV
4278
        if len(dbChan.Bitcoin1Signature) > 0 {
×
4279
                channel.AuthProof = &models.ChannelAuthProof{
×
4280
                        NodeSig1Bytes:    dbChan.Node1Signature,
×
4281
                        NodeSig2Bytes:    dbChan.Node2Signature,
×
4282
                        BitcoinSig1Bytes: dbChan.Bitcoin1Signature,
×
4283
                        BitcoinSig2Bytes: dbChan.Bitcoin2Signature,
×
4284
                }
×
4285
        }
×
4286

4287
        return channel, nil
×
4288
}
4289

4290
// buildNodeVertices is a helper that converts raw node public keys
4291
// into route.Vertex instances.
4292
func buildNodeVertices(node1Pub, node2Pub []byte) (route.Vertex,
4293
        route.Vertex, error) {
×
UNCOV
4294

×
UNCOV
4295
        node1Vertex, err := route.NewVertexFromBytes(node1Pub)
×
UNCOV
4296
        if err != nil {
×
UNCOV
4297
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
UNCOV
4298
                        "create vertex from node1 pubkey: %w", err)
×
UNCOV
4299
        }
×
4300

UNCOV
4301
        node2Vertex, err := route.NewVertexFromBytes(node2Pub)
×
UNCOV
4302
        if err != nil {
×
4303
                return route.Vertex{}, route.Vertex{}, fmt.Errorf("unable to "+
×
4304
                        "create vertex from node2 pubkey: %w", err)
×
4305
        }
×
4306

4307
        return node1Vertex, node2Vertex, nil
×
4308
}
4309

4310
// getAndBuildChanPolicies uses the given sqlc.GraphChannelPolicy and also
4311
// retrieves all the extra info required to build the complete
4312
// models.ChannelEdgePolicy types. It returns two policies, which may be nil if
4313
// the provided sqlc.GraphChannelPolicy records are nil.
4314
func getAndBuildChanPolicies(ctx context.Context, cfg *sqldb.QueryConfig,
4315
        db SQLQueries, dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4316
        channelID uint64, node1, node2 route.Vertex) (*models.ChannelEdgePolicy,
4317
        *models.ChannelEdgePolicy, error) {
×
4318

×
4319
        if dbPol1 == nil && dbPol2 == nil {
×
4320
                return nil, nil, nil
×
4321
        }
×
4322

4323
        var policyIDs = make([]int64, 0, 2)
×
4324
        if dbPol1 != nil {
×
4325
                policyIDs = append(policyIDs, dbPol1.ID)
×
4326
        }
×
4327
        if dbPol2 != nil {
×
4328
                policyIDs = append(policyIDs, dbPol2.ID)
×
UNCOV
4329
        }
×
4330

4331
        batchData, err := batchLoadChannelData(ctx, cfg, db, nil, policyIDs)
×
4332
        if err != nil {
×
4333
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
4334
                        "data: %w", err)
×
4335
        }
×
4336

4337
        pol1, err := buildChanPolicyWithBatchData(
×
UNCOV
4338
                dbPol1, channelID, node2, batchData,
×
UNCOV
4339
        )
×
UNCOV
4340
        if err != nil {
×
UNCOV
4341
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
UNCOV
4342
        }
×
4343

UNCOV
4344
        pol2, err := buildChanPolicyWithBatchData(
×
4345
                dbPol2, channelID, node1, batchData,
×
4346
        )
×
4347
        if err != nil {
×
4348
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
4349
        }
×
4350

4351
        return pol1, pol2, nil
×
4352
}
4353

4354
// buildCachedChanPolicies builds models.CachedEdgePolicy instances from the
4355
// provided sqlc.GraphChannelPolicy objects. If a provided policy is nil,
4356
// then nil is returned for it.
4357
func buildCachedChanPolicies(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
4358
        channelID uint64, node1, node2 route.Vertex) (*models.CachedEdgePolicy,
4359
        *models.CachedEdgePolicy, error) {
×
4360

×
UNCOV
4361
        var p1, p2 *models.CachedEdgePolicy
×
4362
        if dbPol1 != nil {
×
UNCOV
4363
                policy1, err := buildChanPolicy(*dbPol1, channelID, nil, node2)
×
UNCOV
4364
                if err != nil {
×
4365
                        return nil, nil, err
×
UNCOV
4366
                }
×
4367

UNCOV
4368
                p1 = models.NewCachedPolicy(policy1)
×
4369
        }
UNCOV
4370
        if dbPol2 != nil {
×
UNCOV
4371
                policy2, err := buildChanPolicy(*dbPol2, channelID, nil, node1)
×
4372
                if err != nil {
×
4373
                        return nil, nil, err
×
4374
                }
×
4375

4376
                p2 = models.NewCachedPolicy(policy2)
×
4377
        }
4378

UNCOV
4379
        return p1, p2, nil
×
4380
}
4381

4382
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
4383
// provided sqlc.GraphChannelPolicy and other required information.
4384
func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64,
4385
        extras map[uint64][]byte,
4386
        toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
×
4387

×
4388
        recs, err := lnwire.CustomRecords(extras).Serialize()
×
UNCOV
4389
        if err != nil {
×
4390
                return nil, fmt.Errorf("unable to serialize extra signed "+
×
4391
                        "fields: %w", err)
×
4392
        }
×
4393

4394
        var inboundFee fn.Option[lnwire.Fee]
×
4395
        if dbPolicy.InboundFeeRateMilliMsat.Valid ||
×
4396
                dbPolicy.InboundBaseFeeMsat.Valid {
×
4397

×
4398
                inboundFee = fn.Some(lnwire.Fee{
×
4399
                        BaseFee: int32(dbPolicy.InboundBaseFeeMsat.Int64),
×
4400
                        FeeRate: int32(dbPolicy.InboundFeeRateMilliMsat.Int64),
×
4401
                })
×
4402
        }
×
4403

4404
        return &models.ChannelEdgePolicy{
×
4405
                SigBytes:  dbPolicy.Signature,
×
4406
                ChannelID: channelID,
×
4407
                LastUpdate: time.Unix(
×
4408
                        dbPolicy.LastUpdate.Int64, 0,
×
4409
                ),
×
4410
                MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
×
4411
                        dbPolicy.MessageFlags,
×
4412
                ),
×
4413
                ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
×
4414
                        dbPolicy.ChannelFlags,
×
4415
                ),
×
4416
                TimeLockDelta: uint16(dbPolicy.Timelock),
×
UNCOV
4417
                MinHTLC: lnwire.MilliSatoshi(
×
UNCOV
4418
                        dbPolicy.MinHtlcMsat,
×
UNCOV
4419
                ),
×
UNCOV
4420
                MaxHTLC: lnwire.MilliSatoshi(
×
UNCOV
4421
                        dbPolicy.MaxHtlcMsat.Int64,
×
UNCOV
4422
                ),
×
UNCOV
4423
                FeeBaseMSat: lnwire.MilliSatoshi(
×
UNCOV
4424
                        dbPolicy.BaseFeeMsat,
×
UNCOV
4425
                ),
×
4426
                FeeProportionalMillionths: lnwire.MilliSatoshi(dbPolicy.FeePpm),
×
4427
                ToNode:                    toNode,
×
4428
                InboundFee:                inboundFee,
×
4429
                ExtraOpaqueData:           recs,
×
4430
        }, nil
×
4431
}
4432

4433
// extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the given
4434
// row which is expected to be a sqlc type that contains channel policy
4435
// information. It returns two policies, which may be nil if the policy
4436
// information is not present in the row.
4437
//
4438
//nolint:ll,dupl,funlen
4439
func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4440
        *sqlc.GraphChannelPolicy, error) {
×
4441

×
4442
        var policy1, policy2 *sqlc.GraphChannelPolicy
×
4443
        switch r := row.(type) {
×
4444
        case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
×
4445
                if r.Policy1Timelock.Valid {
×
4446
                        policy1 = &sqlc.GraphChannelPolicy{
×
4447
                                Timelock:                r.Policy1Timelock.Int32,
×
4448
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4449
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4450
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4451
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4452
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4453
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4454
                                Disabled:                r.Policy1Disabled,
×
4455
                                MessageFlags:            r.Policy1MessageFlags,
×
4456
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4457
                        }
×
4458
                }
×
UNCOV
4459
                if r.Policy2Timelock.Valid {
×
4460
                        policy2 = &sqlc.GraphChannelPolicy{
×
UNCOV
4461
                                Timelock:                r.Policy2Timelock.Int32,
×
4462
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
4463
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4464
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4465
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4466
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4467
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4468
                                Disabled:                r.Policy2Disabled,
×
4469
                                MessageFlags:            r.Policy2MessageFlags,
×
4470
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4471
                        }
×
4472
                }
×
4473

4474
                return policy1, policy2, nil
×
4475

4476
        case sqlc.GetChannelsBySCIDWithPoliciesRow:
×
4477
                if r.Policy1ID.Valid {
×
4478
                        policy1 = &sqlc.GraphChannelPolicy{
×
4479
                                ID:                      r.Policy1ID.Int64,
×
4480
                                Version:                 r.Policy1Version.Int16,
×
4481
                                ChannelID:               r.GraphChannel.ID,
×
4482
                                NodeID:                  r.Policy1NodeID.Int64,
×
4483
                                Timelock:                r.Policy1Timelock.Int32,
×
4484
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4485
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4486
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4487
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4488
                                LastUpdate:              r.Policy1LastUpdate,
×
4489
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4490
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4491
                                Disabled:                r.Policy1Disabled,
×
4492
                                MessageFlags:            r.Policy1MessageFlags,
×
4493
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4494
                                Signature:               r.Policy1Signature,
×
4495
                        }
×
4496
                }
×
4497
                if r.Policy2ID.Valid {
×
4498
                        policy2 = &sqlc.GraphChannelPolicy{
×
4499
                                ID:                      r.Policy2ID.Int64,
×
4500
                                Version:                 r.Policy2Version.Int16,
×
4501
                                ChannelID:               r.GraphChannel.ID,
×
4502
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4503
                                Timelock:                r.Policy2Timelock.Int32,
×
4504
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4505
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4506
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4507
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4508
                                LastUpdate:              r.Policy2LastUpdate,
×
4509
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4510
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4511
                                Disabled:                r.Policy2Disabled,
×
4512
                                MessageFlags:            r.Policy2MessageFlags,
×
4513
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4514
                                Signature:               r.Policy2Signature,
×
4515
                        }
×
4516
                }
×
4517

4518
                return policy1, policy2, nil
×
4519

4520
        case sqlc.GetChannelByOutpointWithPoliciesRow:
×
4521
                if r.Policy1ID.Valid {
×
4522
                        policy1 = &sqlc.GraphChannelPolicy{
×
4523
                                ID:                      r.Policy1ID.Int64,
×
4524
                                Version:                 r.Policy1Version.Int16,
×
4525
                                ChannelID:               r.GraphChannel.ID,
×
4526
                                NodeID:                  r.Policy1NodeID.Int64,
×
4527
                                Timelock:                r.Policy1Timelock.Int32,
×
4528
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4529
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4530
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4531
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4532
                                LastUpdate:              r.Policy1LastUpdate,
×
4533
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4534
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4535
                                Disabled:                r.Policy1Disabled,
×
4536
                                MessageFlags:            r.Policy1MessageFlags,
×
4537
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4538
                                Signature:               r.Policy1Signature,
×
4539
                        }
×
4540
                }
×
4541
                if r.Policy2ID.Valid {
×
4542
                        policy2 = &sqlc.GraphChannelPolicy{
×
4543
                                ID:                      r.Policy2ID.Int64,
×
4544
                                Version:                 r.Policy2Version.Int16,
×
4545
                                ChannelID:               r.GraphChannel.ID,
×
4546
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4547
                                Timelock:                r.Policy2Timelock.Int32,
×
4548
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4549
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4550
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4551
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4552
                                LastUpdate:              r.Policy2LastUpdate,
×
4553
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4554
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4555
                                Disabled:                r.Policy2Disabled,
×
4556
                                MessageFlags:            r.Policy2MessageFlags,
×
4557
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4558
                                Signature:               r.Policy2Signature,
×
4559
                        }
×
4560
                }
×
4561

4562
                return policy1, policy2, nil
×
4563

4564
        case sqlc.GetChannelBySCIDWithPoliciesRow:
×
4565
                if r.Policy1ID.Valid {
×
4566
                        policy1 = &sqlc.GraphChannelPolicy{
×
4567
                                ID:                      r.Policy1ID.Int64,
×
4568
                                Version:                 r.Policy1Version.Int16,
×
4569
                                ChannelID:               r.GraphChannel.ID,
×
4570
                                NodeID:                  r.Policy1NodeID.Int64,
×
4571
                                Timelock:                r.Policy1Timelock.Int32,
×
4572
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4573
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4574
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4575
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4576
                                LastUpdate:              r.Policy1LastUpdate,
×
4577
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4578
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4579
                                Disabled:                r.Policy1Disabled,
×
4580
                                MessageFlags:            r.Policy1MessageFlags,
×
4581
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4582
                                Signature:               r.Policy1Signature,
×
4583
                        }
×
4584
                }
×
4585
                if r.Policy2ID.Valid {
×
4586
                        policy2 = &sqlc.GraphChannelPolicy{
×
4587
                                ID:                      r.Policy2ID.Int64,
×
4588
                                Version:                 r.Policy2Version.Int16,
×
4589
                                ChannelID:               r.GraphChannel.ID,
×
4590
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4591
                                Timelock:                r.Policy2Timelock.Int32,
×
4592
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4593
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4594
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4595
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4596
                                LastUpdate:              r.Policy2LastUpdate,
×
4597
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4598
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4599
                                Disabled:                r.Policy2Disabled,
×
4600
                                MessageFlags:            r.Policy2MessageFlags,
×
4601
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4602
                                Signature:               r.Policy2Signature,
×
4603
                        }
×
4604
                }
×
4605

4606
                return policy1, policy2, nil
×
4607

4608
        case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
×
4609
                if r.Policy1ID.Valid {
×
4610
                        policy1 = &sqlc.GraphChannelPolicy{
×
4611
                                ID:                      r.Policy1ID.Int64,
×
4612
                                Version:                 r.Policy1Version.Int16,
×
4613
                                ChannelID:               r.GraphChannel.ID,
×
4614
                                NodeID:                  r.Policy1NodeID.Int64,
×
4615
                                Timelock:                r.Policy1Timelock.Int32,
×
4616
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4617
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4618
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4619
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4620
                                LastUpdate:              r.Policy1LastUpdate,
×
4621
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4622
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4623
                                Disabled:                r.Policy1Disabled,
×
4624
                                MessageFlags:            r.Policy1MessageFlags,
×
4625
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4626
                                Signature:               r.Policy1Signature,
×
4627
                        }
×
4628
                }
×
4629
                if r.Policy2ID.Valid {
×
4630
                        policy2 = &sqlc.GraphChannelPolicy{
×
4631
                                ID:                      r.Policy2ID.Int64,
×
4632
                                Version:                 r.Policy2Version.Int16,
×
4633
                                ChannelID:               r.GraphChannel.ID,
×
4634
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4635
                                Timelock:                r.Policy2Timelock.Int32,
×
4636
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4637
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4638
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4639
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4640
                                LastUpdate:              r.Policy2LastUpdate,
×
4641
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4642
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4643
                                Disabled:                r.Policy2Disabled,
×
4644
                                MessageFlags:            r.Policy2MessageFlags,
×
4645
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4646
                                Signature:               r.Policy2Signature,
×
4647
                        }
×
4648
                }
×
4649

4650
                return policy1, policy2, nil
×
4651

4652
        case sqlc.ListChannelsForNodeIDsRow:
×
4653
                if r.Policy1ID.Valid {
×
4654
                        policy1 = &sqlc.GraphChannelPolicy{
×
4655
                                ID:                      r.Policy1ID.Int64,
×
4656
                                Version:                 r.Policy1Version.Int16,
×
4657
                                ChannelID:               r.GraphChannel.ID,
×
4658
                                NodeID:                  r.Policy1NodeID.Int64,
×
4659
                                Timelock:                r.Policy1Timelock.Int32,
×
4660
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4661
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4662
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4663
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4664
                                LastUpdate:              r.Policy1LastUpdate,
×
4665
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4666
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4667
                                Disabled:                r.Policy1Disabled,
×
4668
                                MessageFlags:            r.Policy1MessageFlags,
×
4669
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4670
                                Signature:               r.Policy1Signature,
×
4671
                        }
×
4672
                }
×
4673
                if r.Policy2ID.Valid {
×
4674
                        policy2 = &sqlc.GraphChannelPolicy{
×
4675
                                ID:                      r.Policy2ID.Int64,
×
4676
                                Version:                 r.Policy2Version.Int16,
×
4677
                                ChannelID:               r.GraphChannel.ID,
×
4678
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4679
                                Timelock:                r.Policy2Timelock.Int32,
×
4680
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4681
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4682
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4683
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4684
                                LastUpdate:              r.Policy2LastUpdate,
×
4685
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4686
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4687
                                Disabled:                r.Policy2Disabled,
×
4688
                                MessageFlags:            r.Policy2MessageFlags,
×
4689
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4690
                                Signature:               r.Policy2Signature,
×
4691
                        }
×
4692
                }
×
4693

4694
                return policy1, policy2, nil
×
4695

4696
        case sqlc.ListChannelsByNodeIDRow:
×
4697
                if r.Policy1ID.Valid {
×
4698
                        policy1 = &sqlc.GraphChannelPolicy{
×
4699
                                ID:                      r.Policy1ID.Int64,
×
4700
                                Version:                 r.Policy1Version.Int16,
×
4701
                                ChannelID:               r.GraphChannel.ID,
×
4702
                                NodeID:                  r.Policy1NodeID.Int64,
×
4703
                                Timelock:                r.Policy1Timelock.Int32,
×
4704
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4705
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4706
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4707
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4708
                                LastUpdate:              r.Policy1LastUpdate,
×
4709
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4710
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4711
                                Disabled:                r.Policy1Disabled,
×
4712
                                MessageFlags:            r.Policy1MessageFlags,
×
4713
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4714
                                Signature:               r.Policy1Signature,
×
4715
                        }
×
4716
                }
×
4717
                if r.Policy2ID.Valid {
×
4718
                        policy2 = &sqlc.GraphChannelPolicy{
×
4719
                                ID:                      r.Policy2ID.Int64,
×
4720
                                Version:                 r.Policy2Version.Int16,
×
4721
                                ChannelID:               r.GraphChannel.ID,
×
4722
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4723
                                Timelock:                r.Policy2Timelock.Int32,
×
4724
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4725
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4726
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4727
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4728
                                LastUpdate:              r.Policy2LastUpdate,
×
4729
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4730
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4731
                                Disabled:                r.Policy2Disabled,
×
4732
                                MessageFlags:            r.Policy2MessageFlags,
×
4733
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4734
                                Signature:               r.Policy2Signature,
×
4735
                        }
×
4736
                }
×
4737

4738
                return policy1, policy2, nil
×
4739

4740
        case sqlc.ListChannelsWithPoliciesPaginatedRow:
×
4741
                if r.Policy1ID.Valid {
×
4742
                        policy1 = &sqlc.GraphChannelPolicy{
×
4743
                                ID:                      r.Policy1ID.Int64,
×
4744
                                Version:                 r.Policy1Version.Int16,
×
4745
                                ChannelID:               r.GraphChannel.ID,
×
4746
                                NodeID:                  r.Policy1NodeID.Int64,
×
4747
                                Timelock:                r.Policy1Timelock.Int32,
×
4748
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4749
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4750
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4751
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4752
                                LastUpdate:              r.Policy1LastUpdate,
×
4753
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4754
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4755
                                Disabled:                r.Policy1Disabled,
×
4756
                                MessageFlags:            r.Policy1MessageFlags,
×
4757
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4758
                                Signature:               r.Policy1Signature,
×
4759
                        }
×
4760
                }
×
4761
                if r.Policy2ID.Valid {
×
4762
                        policy2 = &sqlc.GraphChannelPolicy{
×
4763
                                ID:                      r.Policy2ID.Int64,
×
4764
                                Version:                 r.Policy2Version.Int16,
×
4765
                                ChannelID:               r.GraphChannel.ID,
×
4766
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4767
                                Timelock:                r.Policy2Timelock.Int32,
×
4768
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4769
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4770
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4771
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4772
                                LastUpdate:              r.Policy2LastUpdate,
×
4773
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
4774
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
4775
                                Disabled:                r.Policy2Disabled,
×
4776
                                MessageFlags:            r.Policy2MessageFlags,
×
4777
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4778
                                Signature:               r.Policy2Signature,
×
4779
                        }
×
4780
                }
×
4781

4782
                return policy1, policy2, nil
×
4783

4784
        case sqlc.GetChannelsByIDsRow:
×
4785
                if r.Policy1ID.Valid {
×
4786
                        policy1 = &sqlc.GraphChannelPolicy{
×
4787
                                ID:                      r.Policy1ID.Int64,
×
4788
                                Version:                 r.Policy1Version.Int16,
×
4789
                                ChannelID:               r.GraphChannel.ID,
×
4790
                                NodeID:                  r.Policy1NodeID.Int64,
×
4791
                                Timelock:                r.Policy1Timelock.Int32,
×
4792
                                FeePpm:                  r.Policy1FeePpm.Int64,
×
4793
                                BaseFeeMsat:             r.Policy1BaseFeeMsat.Int64,
×
4794
                                MinHtlcMsat:             r.Policy1MinHtlcMsat.Int64,
×
4795
                                MaxHtlcMsat:             r.Policy1MaxHtlcMsat,
×
4796
                                LastUpdate:              r.Policy1LastUpdate,
×
4797
                                InboundBaseFeeMsat:      r.Policy1InboundBaseFeeMsat,
×
4798
                                InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
×
4799
                                Disabled:                r.Policy1Disabled,
×
4800
                                MessageFlags:            r.Policy1MessageFlags,
×
4801
                                ChannelFlags:            r.Policy1ChannelFlags,
×
4802
                                Signature:               r.Policy1Signature,
×
4803
                        }
×
4804
                }
×
4805
                if r.Policy2ID.Valid {
×
4806
                        policy2 = &sqlc.GraphChannelPolicy{
×
4807
                                ID:                      r.Policy2ID.Int64,
×
4808
                                Version:                 r.Policy2Version.Int16,
×
4809
                                ChannelID:               r.GraphChannel.ID,
×
4810
                                NodeID:                  r.Policy2NodeID.Int64,
×
UNCOV
4811
                                Timelock:                r.Policy2Timelock.Int32,
×
4812
                                FeePpm:                  r.Policy2FeePpm.Int64,
×
UNCOV
4813
                                BaseFeeMsat:             r.Policy2BaseFeeMsat.Int64,
×
4814
                                MinHtlcMsat:             r.Policy2MinHtlcMsat.Int64,
×
4815
                                MaxHtlcMsat:             r.Policy2MaxHtlcMsat,
×
4816
                                LastUpdate:              r.Policy2LastUpdate,
×
UNCOV
4817
                                InboundBaseFeeMsat:      r.Policy2InboundBaseFeeMsat,
×
UNCOV
4818
                                InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
×
UNCOV
4819
                                Disabled:                r.Policy2Disabled,
×
UNCOV
4820
                                MessageFlags:            r.Policy2MessageFlags,
×
UNCOV
4821
                                ChannelFlags:            r.Policy2ChannelFlags,
×
4822
                                Signature:               r.Policy2Signature,
×
4823
                        }
×
4824
                }
×
4825

4826
                return policy1, policy2, nil
×
4827

UNCOV
4828
        default:
×
UNCOV
4829
                return nil, nil, fmt.Errorf("unexpected row type in "+
×
4830
                        "extractChannelPolicies: %T", r)
×
4831
        }
4832
}
4833

4834
// channelIDToBytes converts a channel ID (SCID) to a byte array
4835
// representation.
4836
func channelIDToBytes(channelID uint64) []byte {
×
4837
        var chanIDB [8]byte
×
4838
        byteOrder.PutUint64(chanIDB[:], channelID)
×
4839

×
4840
        return chanIDB[:]
×
4841
}
×
4842

4843
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
4844
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
×
4845
        if len(addresses) == 0 {
×
UNCOV
4846
                return nil, nil
×
UNCOV
4847
        }
×
4848

4849
        result := make([]net.Addr, 0, len(addresses))
×
4850
        for _, addr := range addresses {
×
4851
                netAddr, err := parseAddress(addr.addrType, addr.address)
×
UNCOV
4852
                if err != nil {
×
4853
                        return nil, fmt.Errorf("unable to parse address %s "+
×
UNCOV
4854
                                "of type %d: %w", addr.address, addr.addrType,
×
UNCOV
4855
                                err)
×
UNCOV
4856
                }
×
UNCOV
4857
                if netAddr != nil {
×
UNCOV
4858
                        result = append(result, netAddr)
×
4859
                }
×
4860
        }
4861

4862
        // If we have no valid addresses, return nil instead of empty slice.
4863
        if len(result) == 0 {
×
4864
                return nil, nil
×
4865
        }
×
4866

4867
        return result, nil
×
4868
}
4869

4870
// parseAddress parses the given address string based on the address type
4871
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
4872
// and opaque addresses.
4873
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
×
4874
        switch addrType {
×
4875
        case addressTypeIPv4:
×
UNCOV
4876
                tcp, err := net.ResolveTCPAddr("tcp4", address)
×
4877
                if err != nil {
×
UNCOV
4878
                        return nil, err
×
4879
                }
×
4880

4881
                tcp.IP = tcp.IP.To4()
×
4882

×
4883
                return tcp, nil
×
4884

UNCOV
4885
        case addressTypeIPv6:
×
4886
                tcp, err := net.ResolveTCPAddr("tcp6", address)
×
4887
                if err != nil {
×
4888
                        return nil, err
×
4889
                }
×
4890

4891
                return tcp, nil
×
4892

4893
        case addressTypeTorV3, addressTypeTorV2:
×
4894
                service, portStr, err := net.SplitHostPort(address)
×
UNCOV
4895
                if err != nil {
×
4896
                        return nil, fmt.Errorf("unable to split tor "+
×
4897
                                "address: %v", address)
×
4898
                }
×
4899

4900
                port, err := strconv.Atoi(portStr)
×
4901
                if err != nil {
×
UNCOV
4902
                        return nil, err
×
4903
                }
×
4904

4905
                return &tor.OnionAddr{
×
4906
                        OnionService: service,
×
UNCOV
4907
                        Port:         port,
×
4908
                }, nil
×
4909

4910
        case addressTypeDNS:
×
4911
                hostname, portStr, err := net.SplitHostPort(address)
×
UNCOV
4912
                if err != nil {
×
4913
                        return nil, fmt.Errorf("unable to split DNS "+
×
4914
                                "address: %v", address)
×
4915
                }
×
4916

4917
                port, err := strconv.Atoi(portStr)
×
4918
                if err != nil {
×
UNCOV
4919
                        return nil, err
×
4920
                }
×
4921

4922
                return &lnwire.DNSAddress{
×
UNCOV
4923
                        Hostname: hostname,
×
4924
                        Port:     uint16(port),
×
4925
                }, nil
×
4926

UNCOV
4927
        case addressTypeOpaque:
×
UNCOV
4928
                opaque, err := hex.DecodeString(address)
×
UNCOV
4929
                if err != nil {
×
UNCOV
4930
                        return nil, fmt.Errorf("unable to decode opaque "+
×
UNCOV
4931
                                "address: %v", address)
×
UNCOV
4932
                }
×
4933

UNCOV
4934
                return &lnwire.OpaqueAddrs{
×
UNCOV
4935
                        Payload: opaque,
×
UNCOV
4936
                }, nil
×
4937

UNCOV
4938
        default:
×
UNCOV
4939
                return nil, fmt.Errorf("unknown address type: %v", addrType)
×
4940
        }
4941
}
4942

4943
// batchNodeData holds all the related data for a batch of nodes.
4944
type batchNodeData struct {
4945
        // features is a map from a DB node ID to the feature bits for that
4946
        // node.
4947
        features map[int64][]int
4948

4949
        // addresses is a map from a DB node ID to the node's addresses.
4950
        addresses map[int64][]nodeAddress
4951

4952
        // extraFields is a map from a DB node ID to the extra signed fields
4953
        // for that node.
4954
        extraFields map[int64]map[uint64][]byte
4955
}
4956

4957
// nodeAddress holds the address type, position and address string for a
4958
// node. This is used to batch the fetching of node addresses.
4959
type nodeAddress struct {
4960
        addrType dbAddressType
4961
        position int32
4962
        address  string
4963
}
4964

4965
// batchLoadNodeData loads all related data for a batch of node IDs using the
4966
// provided SQLQueries interface. It returns a batchNodeData instance containing
4967
// the node features, addresses and extra signed fields.
4968
func batchLoadNodeData(ctx context.Context, cfg *sqldb.QueryConfig,
4969
        db SQLQueries, nodeIDs []int64) (*batchNodeData, error) {
×
UNCOV
4970

×
UNCOV
4971
        // Batch load the node features.
×
4972
        features, err := batchLoadNodeFeaturesHelper(ctx, cfg, db, nodeIDs)
×
4973
        if err != nil {
×
4974
                return nil, fmt.Errorf("unable to batch load node "+
×
4975
                        "features: %w", err)
×
4976
        }
×
4977

4978
        // Batch load the node addresses.
4979
        addrs, err := batchLoadNodeAddressesHelper(ctx, cfg, db, nodeIDs)
×
4980
        if err != nil {
×
4981
                return nil, fmt.Errorf("unable to batch load node "+
×
4982
                        "addresses: %w", err)
×
UNCOV
4983
        }
×
4984

4985
        // Batch load the node extra signed fields.
UNCOV
4986
        extraTypes, err := batchLoadNodeExtraTypesHelper(ctx, cfg, db, nodeIDs)
×
UNCOV
4987
        if err != nil {
×
UNCOV
4988
                return nil, fmt.Errorf("unable to batch load node extra "+
×
4989
                        "signed fields: %w", err)
×
4990
        }
×
4991

4992
        return &batchNodeData{
×
4993
                features:    features,
×
4994
                addresses:   addrs,
×
4995
                extraFields: extraTypes,
×
4996
        }, nil
×
4997
}
4998

4999
// batchLoadNodeFeaturesHelper loads node features for a batch of node IDs
5000
// using ExecuteBatchQuery wrapper around the GetNodeFeaturesBatch query.
5001
func batchLoadNodeFeaturesHelper(ctx context.Context,
5002
        cfg *sqldb.QueryConfig, db SQLQueries,
5003
        nodeIDs []int64) (map[int64][]int, error) {
×
5004

×
5005
        features := make(map[int64][]int)
×
5006

×
5007
        return features, sqldb.ExecuteBatchQuery(
×
5008
                ctx, cfg, nodeIDs,
×
5009
                func(id int64) int64 {
×
5010
                        return id
×
UNCOV
5011
                },
×
5012
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeFeature,
UNCOV
5013
                        error) {
×
UNCOV
5014

×
UNCOV
5015
                        return db.GetNodeFeaturesBatch(ctx, ids)
×
UNCOV
5016
                },
×
UNCOV
5017
                func(ctx context.Context, feature sqlc.GraphNodeFeature) error {
×
UNCOV
5018
                        features[feature.NodeID] = append(
×
5019
                                features[feature.NodeID],
×
5020
                                int(feature.FeatureBit),
×
5021
                        )
×
5022

×
5023
                        return nil
×
5024
                },
×
5025
        )
5026
}
5027

5028
// batchLoadNodeAddressesHelper loads node addresses using ExecuteBatchQuery
5029
// wrapper around the GetNodeAddressesBatch query. It returns a map from
5030
// node ID to a slice of nodeAddress structs.
5031
func batchLoadNodeAddressesHelper(ctx context.Context,
5032
        cfg *sqldb.QueryConfig, db SQLQueries,
5033
        nodeIDs []int64) (map[int64][]nodeAddress, error) {
×
5034

×
5035
        addrs := make(map[int64][]nodeAddress)
×
5036

×
5037
        return addrs, sqldb.ExecuteBatchQuery(
×
5038
                ctx, cfg, nodeIDs,
×
5039
                func(id int64) int64 {
×
5040
                        return id
×
5041
                },
×
5042
                func(ctx context.Context, ids []int64) ([]sqlc.GraphNodeAddress,
5043
                        error) {
×
UNCOV
5044

×
UNCOV
5045
                        return db.GetNodeAddressesBatch(ctx, ids)
×
UNCOV
5046
                },
×
UNCOV
5047
                func(ctx context.Context, addr sqlc.GraphNodeAddress) error {
×
UNCOV
5048
                        addrs[addr.NodeID] = append(
×
UNCOV
5049
                                addrs[addr.NodeID], nodeAddress{
×
UNCOV
5050
                                        addrType: dbAddressType(addr.Type),
×
UNCOV
5051
                                        position: addr.Position,
×
5052
                                        address:  addr.Address,
×
5053
                                },
×
5054
                        )
×
5055

×
5056
                        return nil
×
5057
                },
×
5058
        )
5059
}
5060

5061
// batchLoadNodeExtraTypesHelper loads node extra type bytes for a batch of
5062
// node IDs using ExecuteBatchQuery wrapper around the GetNodeExtraTypesBatch
5063
// query.
5064
func batchLoadNodeExtraTypesHelper(ctx context.Context,
5065
        cfg *sqldb.QueryConfig, db SQLQueries,
UNCOV
5066
        nodeIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5067

×
5068
        extraFields := make(map[int64]map[uint64][]byte)
×
5069

×
5070
        callback := func(ctx context.Context,
×
5071
                field sqlc.GraphNodeExtraType) error {
×
UNCOV
5072

×
5073
                if extraFields[field.NodeID] == nil {
×
5074
                        extraFields[field.NodeID] = make(map[uint64][]byte)
×
5075
                }
×
5076
                extraFields[field.NodeID][uint64(field.Type)] = field.Value
×
UNCOV
5077

×
UNCOV
5078
                return nil
×
5079
        }
5080

UNCOV
5081
        return extraFields, sqldb.ExecuteBatchQuery(
×
UNCOV
5082
                ctx, cfg, nodeIDs,
×
UNCOV
5083
                func(id int64) int64 {
×
UNCOV
5084
                        return id
×
UNCOV
5085
                },
×
5086
                func(ctx context.Context, ids []int64) (
5087
                        []sqlc.GraphNodeExtraType, error) {
×
5088

×
5089
                        return db.GetNodeExtraTypesBatch(ctx, ids)
×
5090
                },
×
5091
                callback,
5092
        )
5093
}
5094

5095
// buildChanPoliciesWithBatchData builds two models.ChannelEdgePolicy instances
5096
// from the provided sqlc.GraphChannelPolicy records and the
5097
// provided batchChannelData.
5098
func buildChanPoliciesWithBatchData(dbPol1, dbPol2 *sqlc.GraphChannelPolicy,
5099
        channelID uint64, node1, node2 route.Vertex,
5100
        batchData *batchChannelData) (*models.ChannelEdgePolicy,
5101
        *models.ChannelEdgePolicy, error) {
×
UNCOV
5102

×
5103
        pol1, err := buildChanPolicyWithBatchData(
×
UNCOV
5104
                dbPol1, channelID, node2, batchData,
×
UNCOV
5105
        )
×
UNCOV
5106
        if err != nil {
×
UNCOV
5107
                return nil, nil, fmt.Errorf("unable to build policy1: %w", err)
×
UNCOV
5108
        }
×
5109

5110
        pol2, err := buildChanPolicyWithBatchData(
×
5111
                dbPol2, channelID, node1, batchData,
×
5112
        )
×
5113
        if err != nil {
×
5114
                return nil, nil, fmt.Errorf("unable to build policy2: %w", err)
×
UNCOV
5115
        }
×
5116

5117
        return pol1, pol2, nil
×
5118
}
5119

5120
// buildChanPolicyWithBatchData builds a models.ChannelEdgePolicy instance from
5121
// the provided sqlc.GraphChannelPolicy and the provided batchChannelData.
5122
func buildChanPolicyWithBatchData(dbPol *sqlc.GraphChannelPolicy,
5123
        channelID uint64, toNode route.Vertex,
UNCOV
5124
        batchData *batchChannelData) (*models.ChannelEdgePolicy, error) {
×
UNCOV
5125

×
UNCOV
5126
        if dbPol == nil {
×
UNCOV
5127
                return nil, nil
×
UNCOV
5128
        }
×
5129

UNCOV
5130
        var dbPol1Extras map[uint64][]byte
×
UNCOV
5131
        if extras, exists := batchData.policyExtras[dbPol.ID]; exists {
×
UNCOV
5132
                dbPol1Extras = extras
×
UNCOV
5133
        } else {
×
UNCOV
5134
                dbPol1Extras = make(map[uint64][]byte)
×
UNCOV
5135
        }
×
5136

UNCOV
5137
        return buildChanPolicy(*dbPol, channelID, dbPol1Extras, toNode)
×
5138
}
5139

5140
// batchChannelData holds all the related data for a batch of channels and
// their policies.
type batchChannelData struct {
	// chanfeatures is a map from DB channel ID to a slice of feature bits.
	chanfeatures map[int64][]int

	// chanExtraTypes is a map from DB channel ID to a map of TLV type to
	// extra signed field bytes.
	chanExtraTypes map[int64]map[uint64][]byte

	// policyExtras is a map from DB channel policy ID to a map of TLV type
	// to extra signed field bytes.
	policyExtras map[int64]map[uint64][]byte
}
5153

5154
// batchLoadChannelData loads all related data for batches of channels and
5155
// policies.
5156
func batchLoadChannelData(ctx context.Context, cfg *sqldb.QueryConfig,
5157
        db SQLQueries, channelIDs []int64,
5158
        policyIDs []int64) (*batchChannelData, error) {
×
5159

×
5160
        batchData := &batchChannelData{
×
5161
                chanfeatures:   make(map[int64][]int),
×
UNCOV
5162
                chanExtraTypes: make(map[int64]map[uint64][]byte),
×
5163
                policyExtras:   make(map[int64]map[uint64][]byte),
×
5164
        }
×
5165

×
5166
        // Batch load channel features and extras
×
5167
        var err error
×
5168
        if len(channelIDs) > 0 {
×
5169
                batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
×
UNCOV
5170
                        ctx, cfg, db, channelIDs,
×
UNCOV
5171
                )
×
5172
                if err != nil {
×
5173
                        return nil, fmt.Errorf("unable to batch load "+
×
5174
                                "channel features: %w", err)
×
5175
                }
×
5176

5177
                batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
×
5178
                        ctx, cfg, db, channelIDs,
×
5179
                )
×
5180
                if err != nil {
×
UNCOV
5181
                        return nil, fmt.Errorf("unable to batch load "+
×
UNCOV
5182
                                "channel extras: %w", err)
×
5183
                }
×
5184
        }
5185

UNCOV
5186
        if len(policyIDs) > 0 {
×
UNCOV
5187
                policyExtras, err := batchLoadChannelPolicyExtrasHelper(
×
UNCOV
5188
                        ctx, cfg, db, policyIDs,
×
UNCOV
5189
                )
×
UNCOV
5190
                if err != nil {
×
UNCOV
5191
                        return nil, fmt.Errorf("unable to batch load "+
×
5192
                                "policy extras: %w", err)
×
5193
                }
×
5194
                batchData.policyExtras = policyExtras
×
5195
        }
5196

5197
        return batchData, nil
×
5198
}
5199

5200
// batchLoadChannelFeaturesHelper loads channel features for a batch of
5201
// channel IDs using ExecuteBatchQuery wrapper around the
5202
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
5203
// slice of feature bits.
5204
func batchLoadChannelFeaturesHelper(ctx context.Context,
5205
        cfg *sqldb.QueryConfig, db SQLQueries,
UNCOV
5206
        channelIDs []int64) (map[int64][]int, error) {
×
5207

×
5208
        features := make(map[int64][]int)
×
5209

×
5210
        return features, sqldb.ExecuteBatchQuery(
×
5211
                ctx, cfg, channelIDs,
×
5212
                func(id int64) int64 {
×
5213
                        return id
×
5214
                },
×
5215
                func(ctx context.Context,
UNCOV
5216
                        ids []int64) ([]sqlc.GraphChannelFeature, error) {
×
UNCOV
5217

×
UNCOV
5218
                        return db.GetChannelFeaturesBatch(ctx, ids)
×
UNCOV
5219
                },
×
5220
                func(ctx context.Context,
UNCOV
5221
                        feature sqlc.GraphChannelFeature) error {
×
UNCOV
5222

×
UNCOV
5223
                        features[feature.ChannelID] = append(
×
UNCOV
5224
                                features[feature.ChannelID],
×
5225
                                int(feature.FeatureBit),
×
5226
                        )
×
5227

×
5228
                        return nil
×
5229
                },
×
5230
        )
5231
}
5232

5233
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
5234
// channel IDs using ExecuteBatchQuery wrapper around the GetChannelExtrasBatch
5235
// query. It returns a map from DB channel ID to a map of TLV type to extra
5236
// signed field bytes.
5237
func batchLoadChannelExtrasHelper(ctx context.Context,
5238
        cfg *sqldb.QueryConfig, db SQLQueries,
UNCOV
5239
        channelIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5240

×
5241
        extras := make(map[int64]map[uint64][]byte)
×
5242

×
5243
        cb := func(ctx context.Context,
×
5244
                extra sqlc.GraphChannelExtraType) error {
×
UNCOV
5245

×
5246
                if extras[extra.ChannelID] == nil {
×
5247
                        extras[extra.ChannelID] = make(map[uint64][]byte)
×
5248
                }
×
5249
                extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
×
UNCOV
5250

×
UNCOV
5251
                return nil
×
5252
        }
5253

UNCOV
5254
        return extras, sqldb.ExecuteBatchQuery(
×
UNCOV
5255
                ctx, cfg, channelIDs,
×
UNCOV
5256
                func(id int64) int64 {
×
UNCOV
5257
                        return id
×
UNCOV
5258
                },
×
5259
                func(ctx context.Context,
5260
                        ids []int64) ([]sqlc.GraphChannelExtraType, error) {
×
5261

×
5262
                        return db.GetChannelExtrasBatch(ctx, ids)
×
5263
                }, cb,
×
5264
        )
5265
}
5266

5267
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
5268
// batch of policy IDs using ExecuteBatchQuery wrapper around the
5269
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
5270
// a map of TLV type to extra signed field bytes.
5271
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
5272
        cfg *sqldb.QueryConfig, db SQLQueries,
UNCOV
5273
        policyIDs []int64) (map[int64]map[uint64][]byte, error) {
×
5274

×
5275
        extras := make(map[int64]map[uint64][]byte)
×
5276

×
5277
        return extras, sqldb.ExecuteBatchQuery(
×
5278
                ctx, cfg, policyIDs,
×
5279
                func(id int64) int64 {
×
5280
                        return id
×
5281
                },
×
5282
                func(ctx context.Context, ids []int64) (
UNCOV
5283
                        []sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
×
UNCOV
5284

×
UNCOV
5285
                        return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
×
UNCOV
5286
                },
×
5287
                func(ctx context.Context,
UNCOV
5288
                        row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
×
UNCOV
5289

×
UNCOV
5290
                        if extras[row.PolicyID] == nil {
×
UNCOV
5291
                                extras[row.PolicyID] = make(map[uint64][]byte)
×
5292
                        }
×
5293
                        extras[row.PolicyID][uint64(row.Type)] = row.Value
×
5294

×
5295
                        return nil
×
5296
                },
5297
        )
5298
}
5299

5300
// forEachNodePaginated executes a paginated query to process each node in the
5301
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
5302
// and applies the provided processNode function to each node.
5303
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
5304
        db SQLQueries, protocol lnwire.GossipVersion,
5305
        processNode func(context.Context, int64,
5306
                *models.Node) error) error {
×
5307

×
5308
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
UNCOV
5309
                limit int32) ([]sqlc.GraphNode, error) {
×
5310

×
5311
                return db.ListNodesPaginated(
×
5312
                        ctx, sqlc.ListNodesPaginatedParams{
×
UNCOV
5313
                                Version: int16(protocol),
×
5314
                                ID:      lastID,
×
5315
                                Limit:   limit,
×
5316
                        },
×
5317
                )
×
5318
        }
×
5319

5320
        extractPageCursor := func(node sqlc.GraphNode) int64 {
×
5321
                return node.ID
×
5322
        }
×
5323

5324
        collectFunc := func(node sqlc.GraphNode) (int64, error) {
×
5325
                return node.ID, nil
×
5326
        }
×
5327

UNCOV
5328
        batchQueryFunc := func(ctx context.Context,
×
5329
                nodeIDs []int64) (*batchNodeData, error) {
×
UNCOV
5330

×
UNCOV
5331
                return batchLoadNodeData(ctx, cfg, db, nodeIDs)
×
5332
        }
×
5333

5334
        processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
×
5335
                batchData *batchNodeData) error {
×
UNCOV
5336

×
UNCOV
5337
                node, err := buildNodeWithBatchData(dbNode, batchData)
×
UNCOV
5338
                if err != nil {
×
UNCOV
5339
                        return fmt.Errorf("unable to build "+
×
UNCOV
5340
                                "node(id=%d): %w", dbNode.ID, err)
×
UNCOV
5341
                }
×
5342

5343
                return processNode(ctx, dbNode.ID, node)
×
5344
        }
5345

5346
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5347
                ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5348
                collectFunc, batchQueryFunc, processItem,
×
5349
        )
×
5350
}
5351

5352
// forEachChannelWithPolicies executes a paginated query to process each channel
5353
// with policies in the graph.
5354
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
5355
        cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
5356
                *models.ChannelEdgePolicy,
5357
                *models.ChannelEdgePolicy) error) error {
×
5358

×
5359
        type channelBatchIDs struct {
×
5360
                channelID int64
×
5361
                policyIDs []int64
×
UNCOV
5362
        }
×
5363

×
5364
        pageQueryFunc := func(ctx context.Context, lastID int64,
×
5365
                limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5366
                error) {
×
5367

×
UNCOV
5368
                return db.ListChannelsWithPoliciesPaginated(
×
5369
                        ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
×
5370
                                Version: int16(lnwire.GossipVersion1),
×
5371
                                ID:      lastID,
×
5372
                                Limit:   limit,
×
5373
                        },
×
5374
                )
×
5375
        }
×
5376

5377
        extractPageCursor := func(
×
5378
                row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
×
5379

×
5380
                return row.GraphChannel.ID
×
UNCOV
5381
        }
×
5382

5383
        collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
×
5384
                channelBatchIDs, error) {
×
5385

×
5386
                ids := channelBatchIDs{
×
5387
                        channelID: row.GraphChannel.ID,
×
UNCOV
5388
                }
×
5389

×
UNCOV
5390
                // Extract policy IDs from the row.
×
UNCOV
5391
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5392
                if err != nil {
×
5393
                        return ids, err
×
5394
                }
×
5395

5396
                if dbPol1 != nil {
×
5397
                        ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
×
5398
                }
×
5399
                if dbPol2 != nil {
×
5400
                        ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
×
5401
                }
×
5402

5403
                return ids, nil
×
5404
        }
5405

5406
        batchDataFunc := func(ctx context.Context,
×
5407
                allIDs []channelBatchIDs) (*batchChannelData, error) {
×
5408

×
UNCOV
5409
                // Separate channel IDs from policy IDs.
×
UNCOV
5410
                var (
×
5411
                        channelIDs = make([]int64, len(allIDs))
×
5412
                        policyIDs  = make([]int64, 0, len(allIDs)*2)
×
5413
                )
×
5414

×
5415
                for i, ids := range allIDs {
×
5416
                        channelIDs[i] = ids.channelID
×
5417
                        policyIDs = append(policyIDs, ids.policyIDs...)
×
5418
                }
×
5419

5420
                return batchLoadChannelData(
×
UNCOV
5421
                        ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5422
                )
×
5423
        }
5424

5425
        processItem := func(ctx context.Context,
×
5426
                row sqlc.ListChannelsWithPoliciesPaginatedRow,
×
5427
                batchData *batchChannelData) error {
×
5428

×
5429
                node1, node2, err := buildNodeVertices(
×
UNCOV
5430
                        row.Node1Pubkey, row.Node2Pubkey,
×
5431
                )
×
5432
                if err != nil {
×
5433
                        return err
×
5434
                }
×
5435

5436
                edge, err := buildEdgeInfoWithBatchData(
×
5437
                        cfg.ChainHash, row.GraphChannel, node1, node2,
×
5438
                        batchData,
×
5439
                )
×
5440
                if err != nil {
×
5441
                        return fmt.Errorf("unable to build channel info: %w",
×
UNCOV
5442
                                err)
×
5443
                }
×
5444

UNCOV
5445
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5446
                if err != nil {
×
5447
                        return err
×
5448
                }
×
5449

UNCOV
5450
                p1, p2, err := buildChanPoliciesWithBatchData(
×
UNCOV
5451
                        dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
×
UNCOV
5452
                )
×
UNCOV
5453
                if err != nil {
×
UNCOV
5454
                        return err
×
UNCOV
5455
                }
×
5456

5457
                return processChannel(edge, p1, p2)
×
5458
        }
5459

5460
        return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
×
5461
                ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
×
5462
                collectFunc, batchDataFunc, processItem,
×
5463
        )
×
5464
}
5465

5466
// buildDirectedChannel builds a DirectedChannel instance from the provided
5467
// data.
5468
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
5469
        nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
5470
        channelBatchData *batchChannelData, features *lnwire.FeatureVector,
5471
        toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
×
UNCOV
5472

×
5473
        node1, node2, err := buildNodeVertices(
×
5474
                channelRow.Node1Pubkey, channelRow.Node2Pubkey,
×
5475
        )
×
5476
        if err != nil {
×
5477
                return nil, fmt.Errorf("unable to build node vertices: %w", err)
×
UNCOV
5478
        }
×
5479

5480
        edge, err := buildEdgeInfoWithBatchData(
×
5481
                chain, channelRow.GraphChannel, node1, node2, channelBatchData,
×
5482
        )
×
5483
        if err != nil {
×
5484
                return nil, fmt.Errorf("unable to build channel info: %w", err)
×
5485
        }
×
5486

UNCOV
5487
        dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
×
UNCOV
5488
        if err != nil {
×
5489
                return nil, fmt.Errorf("unable to extract channel policies: %w",
×
5490
                        err)
×
5491
        }
×
5492

5493
        p1, p2, err := buildChanPoliciesWithBatchData(
×
5494
                dbPol1, dbPol2, edge.ChannelID, node1, node2,
×
5495
                channelBatchData,
×
5496
        )
×
UNCOV
5497
        if err != nil {
×
UNCOV
5498
                return nil, fmt.Errorf("unable to build channel policies: %w",
×
5499
                        err)
×
5500
        }
×
5501

5502
        // Determine outgoing and incoming policy for this specific node.
5503
        p1ToNode := channelRow.GraphChannel.NodeID2
×
5504
        p2ToNode := channelRow.GraphChannel.NodeID1
×
UNCOV
5505
        outPolicy, inPolicy := p1, p2
×
UNCOV
5506
        if (p1 != nil && p1ToNode == nodeID) ||
×
5507
                (p2 != nil && p2ToNode != nodeID) {
×
5508

×
5509
                outPolicy, inPolicy = p2, p1
×
5510
        }
×
5511

5512
        // Build cached policy.
UNCOV
5513
        var cachedInPolicy *models.CachedEdgePolicy
×
UNCOV
5514
        if inPolicy != nil {
×
5515
                cachedInPolicy = models.NewCachedPolicy(inPolicy)
×
5516
                cachedInPolicy.ToNodePubKey = toNodeCallback
×
5517
                cachedInPolicy.ToNodeFeatures = features
×
5518
        }
×
5519

5520
        // Extract inbound fee.
5521
        var inboundFee lnwire.Fee
×
5522
        if outPolicy != nil {
×
5523
                outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
×
5524
                        inboundFee = fee
×
5525
                })
×
5526
        }
5527

5528
        // Build directed channel.
5529
        directedChannel := &DirectedChannel{
×
UNCOV
5530
                ChannelID:    edge.ChannelID,
×
UNCOV
5531
                IsNode1:      nodePub == edge.NodeKey1Bytes,
×
UNCOV
5532
                OtherNode:    edge.NodeKey2Bytes,
×
UNCOV
5533
                Capacity:     edge.Capacity,
×
UNCOV
5534
                OutPolicySet: outPolicy != nil,
×
5535
                InPolicy:     cachedInPolicy,
×
5536
                InboundFee:   inboundFee,
×
5537
        }
×
5538

×
5539
        if nodePub == edge.NodeKey2Bytes {
×
5540
                directedChannel.OtherNode = edge.NodeKey1Bytes
×
5541
        }
×
5542

5543
        return directedChannel, nil
×
5544
}
5545

5546
// batchBuildChannelEdges builds a slice of ChannelEdge instances from the
5547
// provided rows. It uses batch loading for channels, policies, and nodes.
5548
func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context,
5549
        cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) {
×
5550

×
5551
        var (
×
5552
                channelIDs = make([]int64, len(rows))
×
5553
                policyIDs  = make([]int64, 0, len(rows)*2)
×
5554
                nodeIDs    = make([]int64, 0, len(rows)*2)
×
5555

×
5556
                // nodeIDSet is used to ensure we only collect unique node IDs.
×
5557
                nodeIDSet = make(map[int64]bool)
×
5558

×
5559
                // edges will hold the final channel edges built from the rows.
×
5560
                edges = make([]ChannelEdge, 0, len(rows))
×
5561
        )
×
5562

×
5563
        // Collect all IDs needed for batch loading.
×
5564
        for i, row := range rows {
×
UNCOV
5565
                channelIDs[i] = row.Channel().ID
×
5566

×
5567
                // Collect policy IDs
×
5568
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5569
                if err != nil {
×
5570
                        return nil, fmt.Errorf("unable to extract channel "+
×
5571
                                "policies: %w", err)
×
5572
                }
×
5573
                if dbPol1 != nil {
×
5574
                        policyIDs = append(policyIDs, dbPol1.ID)
×
5575
                }
×
UNCOV
5576
                if dbPol2 != nil {
×
5577
                        policyIDs = append(policyIDs, dbPol2.ID)
×
5578
                }
×
5579

5580
                var (
×
UNCOV
5581
                        node1ID = row.Node1().ID
×
UNCOV
5582
                        node2ID = row.Node2().ID
×
UNCOV
5583
                )
×
5584

×
5585
                // Collect unique node IDs.
×
5586
                if !nodeIDSet[node1ID] {
×
5587
                        nodeIDs = append(nodeIDs, node1ID)
×
5588
                        nodeIDSet[node1ID] = true
×
5589
                }
×
5590

UNCOV
5591
                if !nodeIDSet[node2ID] {
×
UNCOV
5592
                        nodeIDs = append(nodeIDs, node2ID)
×
5593
                        nodeIDSet[node2ID] = true
×
5594
                }
×
5595
        }
5596

5597
        // Batch the data for all the channels and policies.
UNCOV
5598
        channelBatchData, err := batchLoadChannelData(
×
UNCOV
5599
                ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
×
5600
        )
×
5601
        if err != nil {
×
5602
                return nil, fmt.Errorf("unable to batch load channel and "+
×
5603
                        "policy data: %w", err)
×
5604
        }
×
5605

5606
        // Batch the data for all the nodes.
5607
        nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs)
×
5608
        if err != nil {
×
5609
                return nil, fmt.Errorf("unable to batch load node data: %w",
×
5610
                        err)
×
UNCOV
5611
        }
×
5612

5613
        // Build all channel edges using batch data.
5614
        for _, row := range rows {
×
5615
                // Build nodes using batch data.
×
5616
                node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData)
×
5617
                if err != nil {
×
5618
                        return nil, fmt.Errorf("unable to build node1: %w", err)
×
5619
                }
×
5620

UNCOV
5621
                node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData)
×
UNCOV
5622
                if err != nil {
×
5623
                        return nil, fmt.Errorf("unable to build node2: %w", err)
×
5624
                }
×
5625

5626
                // Build channel info using batch data.
5627
                channel, err := buildEdgeInfoWithBatchData(
×
UNCOV
5628
                        cfg.ChainHash, row.Channel(), node1.PubKeyBytes,
×
5629
                        node2.PubKeyBytes, channelBatchData,
×
5630
                )
×
5631
                if err != nil {
×
5632
                        return nil, fmt.Errorf("unable to build channel "+
×
5633
                                "info: %w", err)
×
5634
                }
×
5635

5636
                // Extract and build policies using batch data.
UNCOV
5637
                dbPol1, dbPol2, err := extractChannelPolicies(row)
×
5638
                if err != nil {
×
5639
                        return nil, fmt.Errorf("unable to extract channel "+
×
5640
                                "policies: %w", err)
×
5641
                }
×
5642

5643
                p1, p2, err := buildChanPoliciesWithBatchData(
×
5644
                        dbPol1, dbPol2, channel.ChannelID,
×
UNCOV
5645
                        node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData,
×
UNCOV
5646
                )
×
5647
                if err != nil {
×
UNCOV
5648
                        return nil, fmt.Errorf("unable to build channel "+
×
UNCOV
5649
                                "policies: %w", err)
×
UNCOV
5650
                }
×
5651

UNCOV
5652
                edges = append(edges, ChannelEdge{
×
UNCOV
5653
                        Info:    channel,
×
5654
                        Policy1: p1,
×
5655
                        Policy2: p2,
×
5656
                        Node1:   node1,
×
5657
                        Node2:   node2,
×
5658
                })
×
5659
        }
5660

5661
        return edges, nil
×
5662
}
5663

5664
// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo
5665
// instances from the provided rows using batch loading for channel data.
5666
func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context,
5667
        cfg *SQLStoreConfig, db SQLQueries, rows []T) (
5668
        []*models.ChannelEdgeInfo, []int64, error) {
×
5669

×
5670
        if len(rows) == 0 {
×
5671
                return nil, nil, nil
×
5672
        }
×
5673

5674
        // Collect all the channel IDs needed for batch loading.
UNCOV
5675
        channelIDs := make([]int64, len(rows))
×
5676
        for i, row := range rows {
×
5677
                channelIDs[i] = row.Channel().ID
×
5678
        }
×
5679

5680
        // Batch load the channel data.
5681
        channelBatchData, err := batchLoadChannelData(
×
5682
                ctx, cfg.QueryCfg, db, channelIDs, nil,
×
5683
        )
×
UNCOV
5684
        if err != nil {
×
UNCOV
5685
                return nil, nil, fmt.Errorf("unable to batch load channel "+
×
5686
                        "data: %w", err)
×
5687
        }
×
5688

5689
        // Build all channel edges using batch data.
5690
        edges := make([]*models.ChannelEdgeInfo, 0, len(rows))
×
5691
        for _, row := range rows {
×
5692
                node1, node2, err := buildNodeVertices(
×
UNCOV
5693
                        row.Node1Pub(), row.Node2Pub(),
×
5694
                )
×
UNCOV
5695
                if err != nil {
×
UNCOV
5696
                        return nil, nil, err
×
5697
                }
×
5698

5699
                // Build channel info using batch data
UNCOV
5700
                info, err := buildEdgeInfoWithBatchData(
×
UNCOV
5701
                        cfg.ChainHash, row.Channel(), node1, node2,
×
UNCOV
5702
                        channelBatchData,
×
UNCOV
5703
                )
×
UNCOV
5704
                if err != nil {
×
UNCOV
5705
                        return nil, nil, err
×
5706
                }
×
5707

5708
                edges = append(edges, info)
×
5709
        }
5710

5711
        return edges, channelIDs, nil
×
5712
}
5713

5714
// handleZombieMarking is a helper function that handles the logic of
5715
// marking a channel as a zombie in the database. It takes into account whether
5716
// we are in strict zombie pruning mode, and adjusts the node public keys
5717
// accordingly based on the last update timestamps of the channel policies.
5718
func handleZombieMarking(ctx context.Context, db SQLQueries,
5719
        row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo,
UNCOV
5720
        strictZombiePruning bool, scid uint64) error {
×
5721

×
5722
        nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes
×
5723

×
5724
        if strictZombiePruning {
×
UNCOV
5725
                var e1UpdateTime, e2UpdateTime *time.Time
×
UNCOV
5726
                if row.Policy1LastUpdate.Valid {
×
5727
                        e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0)
×
5728
                        e1UpdateTime = &e1Time
×
5729
                }
×
5730
                if row.Policy2LastUpdate.Valid {
×
5731
                        e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0)
×
5732
                        e2UpdateTime = &e2Time
×
5733
                }
×
5734

UNCOV
5735
                nodeKey1, nodeKey2 = makeZombiePubkeys(
×
UNCOV
5736
                        info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime,
×
UNCOV
5737
                        e2UpdateTime,
×
UNCOV
5738
                )
×
5739
        }
5740

UNCOV
5741
        return db.UpsertZombieChannel(
×
UNCOV
5742
                ctx, sqlc.UpsertZombieChannelParams{
×
UNCOV
5743
                        Version:  int16(lnwire.GossipVersion1),
×
UNCOV
5744
                        Scid:     channelIDToBytes(scid),
×
UNCOV
5745
                        NodeKey1: nodeKey1[:],
×
UNCOV
5746
                        NodeKey2: nodeKey2[:],
×
UNCOV
5747
                },
×
UNCOV
5748
        )
×
5749
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc